diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml
index eefb1ebbb9..c55b013dbe 100644
--- a/.github/workflows/pyrefly-diff-comment.yml
+++ b/.github/workflows/pyrefly-diff-comment.yml
@@ -76,13 +76,11 @@ jobs:
diff += '\\n\\n... (truncated) ...';
}
- const body = diff.trim()
- ? '### Pyrefly Diff\n\nbase → PR
\n\n```diff\n' + diff + '\n```\n '
- : '### Pyrefly Diff\nNo changes detected.';
-
- await github.rest.issues.createComment({
- issue_number: prNumber,
- owner: context.repo.owner,
- repo: context.repo.repo,
- body,
- });
+ if (diff.trim()) {
+ await github.rest.issues.createComment({
+ issue_number: prNumber,
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ body: '### Pyrefly Diff\n\nbase → PR
\n\n```diff\n' + diff + '\n```\n ',
+ });
+ }
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index f3ab4c62c7..2a5cf19645 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -89,3 +89,37 @@ jobs:
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
+
+ dify-ui-test:
+ name: dify-ui Tests
+ runs-on: ubuntu-latest
+ env:
+ CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+ defaults:
+ run:
+ shell: bash
+ working-directory: ./packages/dify-ui
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ with:
+ persist-credentials: false
+
+ - name: Setup web environment
+ uses: ./.github/actions/setup-web
+
+ - name: Install Chromium for Browser Mode
+ run: vp exec playwright install --with-deps chromium
+
+ - name: Run dify-ui tests
+ run: vp test run --coverage --silent=passed-only
+
+ - name: Report coverage
+ if: ${{ env.CODECOV_TOKEN != '' }}
+ uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
+ with:
+ directory: packages/dify-ui/coverage
+ flags: dify-ui
+ env:
+ CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
diff --git a/api/commands/account.py b/api/commands/account.py
index 6a2a2e0428..761323a73d 100644
--- a/api/commands/account.py
+++ b/api/commands/account.py
@@ -2,6 +2,7 @@ import base64
import secrets
import click
+from sqlalchemy.orm import Session
from constants.languages import languages
from extensions.ext_database import db
@@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm):
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
- account = db.session.merge(account)
- account.password = base64_password_hashed
- account.password_salt = base64_salt
- db.session.commit()
+ with Session(db.engine) as session:
+ account = session.merge(account)
+ account.password = base64_password_hashed
+ account.password_salt = base64_salt
+ session.commit()
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
- account = db.session.merge(account)
- account.email = normalized_new_email
- db.session.commit()
+ with Session(db.engine) as session:
+ account = session.merge(account)
+ account.email = normalized_new_email
+ session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py
new file mode 100644
index 0000000000..b0fbe0075c
--- /dev/null
+++ b/api/constants/dsl_version.py
@@ -0,0 +1 @@
+CURRENT_APP_DSL_VERSION = "0.6.0"
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index cead33d14f..9c8b095b9f 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
return value
try:
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index f6319573e0..e32ba5f66c 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
class FullContentDict(TypedDict):
@@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
result: FullContentDict = {
"size_bytes": variable_file.size,
- "value_type": variable_file.value_type.exposed_type().value,
+ "value_type": str(variable_file.value_type.exposed_type()),
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
@@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
- "value_type": v.value_type.exposed_type().value,
+ "value_type": str(v.value_type.exposed_type()),
"value": v.value,
# Do not track edited for env vars.
"edited": False,
diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py
index a5846e2815..2f309262cb 100644
--- a/api/controllers/inner_api/plugin/wraps.py
+++ b/api/controllers/inner_api/plugin/wraps.py
@@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel):
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
- Get current user
+ Get current user.
NOTE: user_id is not trusted, it could be maliciously set to any value.
- As a result, it could only be considered as an end user id.
+ As a result, it could only be considered as an end user id. Even when a
+ concrete end-user ID is supplied, lookups must stay tenant-scoped so one
+ tenant cannot bind another tenant's user record into the plugin request
+ context.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
@@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
.limit(1)
)
else:
- user_model = session.get(EndUser, user_id)
+ user_model = session.scalar(
+ select(EndUser)
+ .where(
+ EndUser.id == user_id,
+ EndUser.tenant_id == tenant_id,
+ )
+ .limit(1)
+ )
if not user_model:
user_model = EndUser(
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index c4353ca7b8..ca4b18cb5e 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
def normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 790602ef5d..c22102c2ba 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index d38d24d1e7..29de0b8b1c 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -299,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# update prompt tool
for prompt_tool in prompt_messages_tools:
- self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
+ tool_instance = tool_instances.get(prompt_tool.name)
+ if tool_instance:
+ self.update_prompt_message_tool(tool_instance, prompt_tool)
iteration_step += 1
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index dbd7527fc6..5df3df2b3e 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class ModelConfigConverter:
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 09ddce327e..cae0eee0df 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
from extensions.ext_database import db
from graphon.model_runtime.entities.llm_entities import LLMMode
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import App, Conversation, Message
logger = logging.getLogger(__name__)
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index dfe6133cb6..e2e07ebaff 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py
index 68e5e5f0c8..3a6f9d575a 100644
--- a/api/core/app/workflow/file_runtime.py
+++ b/api/core/app/workflow/file_runtime.py
@@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
from core.db.session_factory import session_factory
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.tools.signature import sign_tool_file
from core.workflow.file_reference import parse_file_reference
from extensions.ext_storage import storage
from graphon.file import FileTransferMethod
-from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
+from graphon.file.protocols import WorkflowFileRuntimeProtocol
from graphon.file.runtime import set_workflow_file_runtime
+from graphon.http.protocols import HttpResponseProtocol
if TYPE_CHECKING:
from graphon.file import File
@@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
return dify_config.MULTIMODAL_SEND_FORMAT
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
- return ssrf_proxy.get(url, follow_redirects=follow_redirects)
+ return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py
index 87f005a250..d521304615 100644
--- a/api/core/app/workflow/layers/persistence.py
+++ b/api/core/app/workflow/layers/persistence.py
@@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
- execution.exceptions_count = runtime_state.exceptions_count
+ execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
def _update_node_execution(
self,
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index dc831e5cac..f0dcb13b62 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -352,11 +352,11 @@ class DatasourceManager:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.CUSTOM,
+ file_type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(record_id=str(upload_file.id)),
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 1ab66cceee..38b87e2cd1 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
FormType,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.ai_model import AIModel
+from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
@@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
else [],
)
- def validate_provider_credentials(
- self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
- ):
+ def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
"""
Validate custom credentials.
:param credentials: provider credentials
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
- :param session: optional database session
:return:
"""
+ provider_credential_secret_variables = self.extract_secret_variables(
+ self.provider.provider_credential_schema.credential_form_schemas
+ if self.provider.provider_credential_schema
+ else []
+ )
- def _validate(s: Session):
- # Get provider credential secret variables
- provider_credential_secret_variables = self.extract_secret_variables(
- self.provider.provider_credential_schema.credential_form_schemas
- if self.provider.provider_credential_schema
- else []
- )
-
- if credential_id:
+ if credential_id:
+ with Session(db.engine) as session:
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
- credential_record = s.execute(stmt).scalar_one_or_none()
- # fix origin data
+ credential_record = session.execute(stmt).scalar_one_or_none()
if credential_record and credential_record.encrypted_config:
if not credential_record.encrypted_config.startswith("{"):
original_credentials = {"openai_api_key": credential_record.encrypted_config}
@@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
- # encrypt credentials
- for key, value in credentials.items():
- if key in provider_credential_secret_variables:
- # if send [__HIDDEN__] in secret input, it will be same as original value
- if value == HIDDEN_VALUE and key in original_credentials:
- credentials[key] = encrypter.decrypt_token(
- tenant_id=self.tenant_id, token=original_credentials[key]
- )
-
- model_provider_factory = self.get_model_provider_factory()
- validated_credentials = model_provider_factory.provider_credentials_validate(
- provider=self.provider.provider, credentials=credentials
- )
-
- for key, value in validated_credentials.items():
+ for key, value in credentials.items():
if key in provider_credential_secret_variables:
- validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+ if value == HIDDEN_VALUE and key in original_credentials:
+ credentials[key] = encrypter.decrypt_token(
+ tenant_id=self.tenant_id, token=original_credentials[key]
+ )
- return validated_credentials
+ model_provider_factory = self.get_model_provider_factory()
+ validated_credentials = model_provider_factory.provider_credentials_validate(
+ provider=self.provider.provider, credentials=credentials
+ )
- if session:
- return _validate(session)
- else:
- with Session(db.engine) as new_session:
- return _validate(new_session)
+ for key, value in validated_credentials.items():
+ if key in provider_credential_secret_variables and isinstance(value, str):
+ validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+ return validated_credentials
def _generate_provider_credential_name(self, session) -> str:
"""
@@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name:
- if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
+ if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
- credential_name = self._generate_provider_credential_name(session)
+ credential_name = self._generate_provider_credential_name(pre_session)
- credentials = self.validate_provider_credentials(credentials=credentials, session=session)
+ credentials = self.validate_provider_credentials(credentials=credentials)
+
+ with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
try:
new_record = ProviderCredential(
@@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
session.flush()
if not provider_record:
- # If provider record does not exist, create it
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
@@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name and self._check_provider_credential_name_exists(
- credential_name=credential_name, session=session, exclude_id=credential_id
+ credential_name=credential_name, session=pre_session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
- credentials = self.validate_provider_credentials(
- credentials=credentials, credential_id=credential_id, session=session
- )
+ credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
+
+ with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
@@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
- # Get the credential record to update
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
try:
- # Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
@@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
model: str,
credentials: dict[str, Any],
credential_id: str = "",
- session: Session | None = None,
):
"""
Validate custom model credentials.
@@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:return:
"""
+ provider_credential_secret_variables = self.extract_secret_variables(
+ self.provider.model_credential_schema.credential_form_schemas
+ if self.provider.model_credential_schema
+ else []
+ )
- def _validate(s: Session):
- # Get provider credential secret variables
- provider_credential_secret_variables = self.extract_secret_variables(
- self.provider.model_credential_schema.credential_form_schemas
- if self.provider.model_credential_schema
- else []
- )
-
- if credential_id:
+ if credential_id:
+ with Session(db.engine) as session:
try:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
@@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
)
- credential_record = s.execute(stmt).scalar_one_or_none()
+ credential_record = session.execute(stmt).scalar_one_or_none()
original_credentials = (
json.loads(credential_record.encrypted_config)
if credential_record and credential_record.encrypted_config
@@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
- # decrypt credentials
- for key, value in credentials.items():
- if key in provider_credential_secret_variables:
- # if send [__HIDDEN__] in secret input, it will be same as original value
- if value == HIDDEN_VALUE and key in original_credentials:
- credentials[key] = encrypter.decrypt_token(
- tenant_id=self.tenant_id, token=original_credentials[key]
- )
-
- model_provider_factory = self.get_model_provider_factory()
- validated_credentials = model_provider_factory.model_credentials_validate(
- provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
- )
-
- for key, value in validated_credentials.items():
+ for key, value in credentials.items():
if key in provider_credential_secret_variables:
- validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+ if value == HIDDEN_VALUE and key in original_credentials:
+ credentials[key] = encrypter.decrypt_token(
+ tenant_id=self.tenant_id, token=original_credentials[key]
+ )
- return validated_credentials
+ model_provider_factory = self.get_model_provider_factory()
+ validated_credentials = model_provider_factory.model_credentials_validate(
+ provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
+ )
- if session:
- return _validate(session)
- else:
- with Session(db.engine) as new_session:
- return _validate(new_session)
+ for key, value in validated_credentials.items():
+ if key in provider_credential_secret_variables and isinstance(value, str):
+ validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
+
+ return validated_credentials
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
:param credentials: model credentials dict
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name:
if self._check_custom_model_credential_name_exists(
- model=model, model_type=model_type, credential_name=credential_name, session=session
+ model=model, model_type=model_type, credential_name=credential_name, session=pre_session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
- model=model, model_type=model_type, session=session
+ model=model, model_type=model_type, session=pre_session
)
- # validate custom model config
- credentials = self.validate_custom_model_credentials(
- model_type=model_type, model=model, credentials=credentials, session=session
- )
+
+ credentials = self.validate_custom_model_credentials(
+ model_type=model_type, model=model, credentials=credentials
+ )
+
+ with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
try:
@@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
session.add(credential)
session.flush()
- # save provider model
if not provider_model_record:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
@@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
:param credential_id: credential id
:return:
"""
- with Session(db.engine) as session:
+ with Session(db.engine) as pre_session:
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
- session=session,
+ session=pre_session,
exclude_id=credential_id,
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
- # validate custom model config
- credentials = self.validate_custom_model_credentials(
- model_type=model_type,
- model=model,
- credentials=credentials,
- credential_id=credential_id,
- session=session,
- )
+
+ credentials = self.validate_custom_model_credentials(
+ model_type=model_type,
+ model=model,
+ credentials=credentials,
+ credential_id=credential_id,
+ )
+
+ with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
@@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Credential record not found.")
try:
- # Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index b96a9ce380..38864a1830 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
- inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
+ inputs_json_str = dumps_with_segments(inputs).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py
index dc37a36943..f169f247cf 100644
--- a/api/core/helper/moderation.py
+++ b/api/core/helper/moderation.py
@@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py
index e38592bb7b..91e92712b7 100644
--- a/api/core/helper/ssrf_proxy.py
+++ b/api/core/helper/ssrf_proxy.py
@@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.tools.errors import ToolSSRFError
+from graphon.http.response import HttpResponse
logger = logging.getLogger(__name__)
@@ -267,4 +268,47 @@ class SSRFProxy:
return patch(url=url, max_retries=max_retries, **kwargs)
+def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
+ """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
+ return HttpResponse(
+ status_code=response.status_code,
+ headers=dict(response.headers),
+ content=response.content,
+ url=str(response.url) if response.url else None,
+ reason_phrase=response.reason_phrase,
+ fallback_text=response.text,
+ )
+
+
+class GraphonSSRFProxy:
+ """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
+
+ @property
+ def max_retries_exceeded_error(self) -> type[Exception]:
+ return max_retries_exceeded_error
+
+ @property
+ def request_error(self) -> type[Exception]:
+ return request_error
+
+ def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
+
+ def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
+
+ def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
+
+ def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
+
+ def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
+
+ def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
+ return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
+
+
ssrf_proxy = SSRFProxy()
+graphon_ssrf_proxy = GraphonSSRFProxy()
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index d8d8dfedd8..86d0e3baaa 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
-from typing import IO, Any, Literal, Optional, Union, cast, overload
+from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
from core.entities import PluginCredentialType
@@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from models.provider import ProviderType
logger = logging.getLogger(__name__)
+P = ParamSpec("P")
+R = TypeVar("R")
class ModelInstance:
@@ -168,7 +170,7 @@ class ModelInstance:
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -193,7 +195,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
prompt_messages=list(prompt_messages),
@@ -213,7 +215,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -235,7 +237,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
@@ -252,7 +254,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
return self._round_robin_invoke(
- function=self.model_type_instance.get_num_tokens,
+ self.model_type_instance.get_num_tokens,
model=self.model_name,
credentials=self.credentials,
texts=texts,
@@ -277,7 +279,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -305,7 +307,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke_multimodal_rerank,
+ self.model_type_instance.invoke_multimodal_rerank,
model=self.model_name,
credentials=self.credentials,
query=query,
@@ -324,7 +326,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
text=text,
@@ -340,7 +342,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
file=file,
@@ -357,14 +359,14 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self._round_robin_invoke(
- function=self.model_type_instance.invoke,
+ self.model_type_instance.invoke,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
voice=voice,
)
- def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
+ def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
"""
Round-robin invoke
:param function: function to invoke
diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py
index fda00ac3b9..d78ce90aa1 100644
--- a/api/core/ops/entities/config_entity.py
+++ b/api/core/ops/entities/config_entity.py
@@ -1,8 +1,8 @@
from enum import StrEnum
-from pydantic import BaseModel, ValidationInfo, field_validator
+from pydantic import BaseModel
-from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path
+from core.ops.utils import validate_project_name, validate_url
class TracingProviderEnum(StrEnum):
@@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel):
return validate_project_name(v, default_name)
-class ArizeConfig(BaseTracingConfig):
- """
- Model class for Arize tracing config.
- """
-
- api_key: str | None = None
- space_id: str | None = None
- project: str | None = None
- endpoint: str = "https://otlp.arize.com"
-
- @field_validator("project")
- @classmethod
- def project_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "default")
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- return cls.validate_endpoint_url(v, "https://otlp.arize.com")
-
-
-class PhoenixConfig(BaseTracingConfig):
- """
- Model class for Phoenix tracing config.
- """
-
- api_key: str | None = None
- project: str | None = None
- endpoint: str = "https://app.phoenix.arize.com"
-
- @field_validator("project")
- @classmethod
- def project_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "default")
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- return validate_url_with_path(v, "https://app.phoenix.arize.com")
-
-
-class LangfuseConfig(BaseTracingConfig):
- """
- Model class for Langfuse tracing config.
- """
-
- public_key: str
- secret_key: str
- host: str = "https://api.langfuse.com"
-
- @field_validator("host")
- @classmethod
- def host_validator(cls, v, info: ValidationInfo):
- return validate_url_with_path(v, "https://api.langfuse.com")
-
-
-class LangSmithConfig(BaseTracingConfig):
- """
- Model class for Langsmith tracing config.
- """
-
- api_key: str
- project: str
- endpoint: str = "https://api.smith.langchain.com"
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- # LangSmith only allows HTTPS
- return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
-
-
-class OpikConfig(BaseTracingConfig):
- """
- Model class for Opik tracing config.
- """
-
- api_key: str | None = None
- project: str | None = None
- workspace: str | None = None
- url: str = "https://www.comet.com/opik/api/"
-
- @field_validator("project")
- @classmethod
- def project_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "Default Project")
-
- @field_validator("url")
- @classmethod
- def url_validator(cls, v, info: ValidationInfo):
- return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
-
-
-class WeaveConfig(BaseTracingConfig):
- """
- Model class for Weave tracing config.
- """
-
- api_key: str
- entity: str | None = None
- project: str
- endpoint: str = "https://trace.wandb.ai"
- host: str | None = None
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- # Weave only allows HTTPS for endpoint
- return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
-
- @field_validator("host")
- @classmethod
- def host_validator(cls, v, info: ValidationInfo):
- if v is not None and v.strip() != "":
- return validate_url(v, v, allowed_schemes=("https", "http"))
- return v
-
-
-class AliyunConfig(BaseTracingConfig):
- """
- Model class for Aliyun tracing config.
- """
-
- app_name: str = "dify_app"
- license_key: str
- endpoint: str
-
- @field_validator("app_name")
- @classmethod
- def app_name_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "dify_app")
-
- @field_validator("license_key")
- @classmethod
- def license_key_validator(cls, v, info: ValidationInfo):
- if not v or v.strip() == "":
- raise ValueError("License key cannot be empty")
- return v
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- # aliyun uses two URL formats, which may include a URL path
- return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
-
-
-class TencentConfig(BaseTracingConfig):
- """
- Tencent APM tracing config
- """
-
- token: str
- endpoint: str
- service_name: str
-
- @field_validator("token")
- @classmethod
- def token_validator(cls, v, info: ValidationInfo):
- if not v or v.strip() == "":
- raise ValueError("Token cannot be empty")
- return v
-
- @field_validator("endpoint")
- @classmethod
- def endpoint_validator(cls, v, info: ValidationInfo):
- return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
-
- @field_validator("service_name")
- @classmethod
- def service_name_validator(cls, v, info: ValidationInfo):
- return cls.validate_project_field(v, "dify_app")
-
-
-class MLflowConfig(BaseTracingConfig):
- """
- Model class for MLflow tracing config.
- """
-
- tracking_uri: str = "http://localhost:5000"
- experiment_id: str = "0" # Default experiment id in MLflow is 0
- username: str | None = None
- password: str | None = None
-
- @field_validator("tracking_uri")
- @classmethod
- def tracking_uri_validator(cls, v, info: ValidationInfo):
- if isinstance(v, str) and v.startswith("databricks"):
- raise ValueError(
- "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
- )
- return validate_url_with_path(v, "http://localhost:5000")
-
- @field_validator("experiment_id")
- @classmethod
- def experiment_id_validator(cls, v, info: ValidationInfo):
- return validate_integer_id(v)
-
-
-class DatabricksConfig(BaseTracingConfig):
- """
- Model class for Databricks (Databricks-managed MLflow) tracing config.
- """
-
- experiment_id: str
- host: str
- client_id: str | None = None
- client_secret: str | None = None
- personal_access_token: str | None = None
-
- @field_validator("experiment_id")
- @classmethod
- def experiment_id_validator(cls, v, info: ValidationInfo):
- return validate_integer_id(v)
-
-
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index cd63951537..e7ba6e502b 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict):
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
- match provider:
- case TracingProviderEnum.LANGFUSE:
- from core.ops.entities.config_entity import LangfuseConfig
- from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
+ try:
+ match provider:
+ case TracingProviderEnum.LANGFUSE:
+ from dify_trace_langfuse.config import LangfuseConfig
+ from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
- return {
- "config_class": LangfuseConfig,
- "secret_keys": ["public_key", "secret_key"],
- "other_keys": ["host", "project_key"],
- "trace_instance": LangFuseDataTrace,
- }
+ return {
+ "config_class": LangfuseConfig,
+ "secret_keys": ["public_key", "secret_key"],
+ "other_keys": ["host", "project_key"],
+ "trace_instance": LangFuseDataTrace,
+ }
- case TracingProviderEnum.LANGSMITH:
- from core.ops.entities.config_entity import LangSmithConfig
- from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
+ case TracingProviderEnum.LANGSMITH:
+ from dify_trace_langsmith.config import LangSmithConfig
+ from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
- return {
- "config_class": LangSmithConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "endpoint"],
- "trace_instance": LangSmithDataTrace,
- }
+ return {
+ "config_class": LangSmithConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": LangSmithDataTrace,
+ }
- case TracingProviderEnum.OPIK:
- from core.ops.entities.config_entity import OpikConfig
- from core.ops.opik_trace.opik_trace import OpikDataTrace
+ case TracingProviderEnum.OPIK:
+ from dify_trace_opik.config import OpikConfig
+ from dify_trace_opik.opik_trace import OpikDataTrace
- return {
- "config_class": OpikConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "url", "workspace"],
- "trace_instance": OpikDataTrace,
- }
+ return {
+ "config_class": OpikConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "url", "workspace"],
+ "trace_instance": OpikDataTrace,
+ }
- case TracingProviderEnum.WEAVE:
- from core.ops.entities.config_entity import WeaveConfig
- from core.ops.weave_trace.weave_trace import WeaveDataTrace
+ case TracingProviderEnum.WEAVE:
+ from dify_trace_weave.config import WeaveConfig
+ from dify_trace_weave.weave_trace import WeaveDataTrace
- return {
- "config_class": WeaveConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "entity", "endpoint", "host"],
- "trace_instance": WeaveDataTrace,
- }
- case TracingProviderEnum.ARIZE:
- from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
- from core.ops.entities.config_entity import ArizeConfig
+ return {
+ "config_class": WeaveConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "entity", "endpoint", "host"],
+ "trace_instance": WeaveDataTrace,
+ }
+ case TracingProviderEnum.ARIZE:
+ from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
+ from dify_trace_arize_phoenix.config import ArizeConfig
- return {
- "config_class": ArizeConfig,
- "secret_keys": ["api_key", "space_id"],
- "other_keys": ["project", "endpoint"],
- "trace_instance": ArizePhoenixDataTrace,
- }
- case TracingProviderEnum.PHOENIX:
- from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
- from core.ops.entities.config_entity import PhoenixConfig
+ return {
+ "config_class": ArizeConfig,
+ "secret_keys": ["api_key", "space_id"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": ArizePhoenixDataTrace,
+ }
+ case TracingProviderEnum.PHOENIX:
+ from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace
+ from dify_trace_arize_phoenix.config import PhoenixConfig
- return {
- "config_class": PhoenixConfig,
- "secret_keys": ["api_key"],
- "other_keys": ["project", "endpoint"],
- "trace_instance": ArizePhoenixDataTrace,
- }
- case TracingProviderEnum.ALIYUN:
- from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
- from core.ops.entities.config_entity import AliyunConfig
+ return {
+ "config_class": PhoenixConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": ArizePhoenixDataTrace,
+ }
+ case TracingProviderEnum.ALIYUN:
+ from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
+ from dify_trace_aliyun.config import AliyunConfig
- return {
- "config_class": AliyunConfig,
- "secret_keys": ["license_key"],
- "other_keys": ["endpoint", "app_name"],
- "trace_instance": AliyunDataTrace,
- }
- case TracingProviderEnum.MLFLOW:
- from core.ops.entities.config_entity import MLflowConfig
- from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
+ return {
+ "config_class": AliyunConfig,
+ "secret_keys": ["license_key"],
+ "other_keys": ["endpoint", "app_name"],
+ "trace_instance": AliyunDataTrace,
+ }
+ case TracingProviderEnum.MLFLOW:
+ from dify_trace_mlflow.config import MLflowConfig
+ from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
- return {
- "config_class": MLflowConfig,
- "secret_keys": ["password"],
- "other_keys": ["tracking_uri", "experiment_id", "username"],
- "trace_instance": MLflowDataTrace,
- }
- case TracingProviderEnum.DATABRICKS:
- from core.ops.entities.config_entity import DatabricksConfig
- from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace
+ return {
+ "config_class": MLflowConfig,
+ "secret_keys": ["password"],
+ "other_keys": ["tracking_uri", "experiment_id", "username"],
+ "trace_instance": MLflowDataTrace,
+ }
+ case TracingProviderEnum.DATABRICKS:
+ from dify_trace_mlflow.config import DatabricksConfig
+ from dify_trace_mlflow.mlflow_trace import MLflowDataTrace
- return {
- "config_class": DatabricksConfig,
- "secret_keys": ["personal_access_token", "client_secret"],
- "other_keys": ["host", "client_id", "experiment_id"],
- "trace_instance": MLflowDataTrace,
- }
+ return {
+ "config_class": DatabricksConfig,
+ "secret_keys": ["personal_access_token", "client_secret"],
+ "other_keys": ["host", "client_id", "experiment_id"],
+ "trace_instance": MLflowDataTrace,
+ }
- case TracingProviderEnum.TENCENT:
- from core.ops.entities.config_entity import TencentConfig
- from core.ops.tencent_trace.tencent_trace import TencentDataTrace
+ case TracingProviderEnum.TENCENT:
+ from dify_trace_tencent.config import TencentConfig
+ from dify_trace_tencent.tencent_trace import TencentDataTrace
- return {
- "config_class": TencentConfig,
- "secret_keys": ["token"],
- "other_keys": ["endpoint", "service_name"],
- "trace_instance": TencentDataTrace,
- }
+ return {
+ "config_class": TencentConfig,
+ "secret_keys": ["token"],
+ "other_keys": ["endpoint", "service_name"],
+ "trace_instance": TencentDataTrace,
+ }
- case _:
- raise KeyError(f"Unsupported tracing provider: {provider}")
+ case _:
+ raise KeyError(f"Unsupported tracing provider: {provider}")
+ except ImportError:
+ raise ImportError(f"Provider {provider} is not installed.")
provider_config_map = OpsTraceProviderConfigMap()
diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py
index e3fba4ef3a..4e66d58b5e 100644
--- a/api/core/plugin/impl/model_runtime.py
+++ b/api/core/plugin/impl/model_runtime.py
@@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
- provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
+ provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
- provider_schema.icon_small_dark.zh_Hans
+ provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
- else provider_schema.icon_small_dark.en_US
+ else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py
index 8f1d51f08a..7c6280fe93 100644
--- a/api/core/prompt/agent_history_prompt_transform.py
+++ b/api/core/prompt/agent_history_prompt_transform.py
@@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
class AgentHistoryPromptTransform(PromptTransform):
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index 4926f44f16..a9995778f7 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from models.dataset import Embedding
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 052fca930d..0330a43b28 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -3,6 +3,7 @@
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
"""
+import inspect
import logging
import mimetypes
import os
@@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
+ _closed: bool
+
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
+ self._closed = False
self.file_path = file_path
self.tenant_id = tenant_id
self.user_id = user_id
@@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
+ def close(self) -> None:
+ """Best-effort cleanup for downloaded temporary files."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ temp_file = getattr(self, "temp_file", None)
+ if temp_file is None:
+ return
+
+ try:
+ close_result = temp_file.close()
+ if inspect.isawaitable(close_result):
+ close_awaitable = getattr(close_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
+ except Exception:
+ logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
+
def __del__(self):
- if hasattr(self, "temp_file"):
- self.temp_file.close()
+ self.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index f8242efe31..7ffa9afafd 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 1453fe020b..5631b3a921 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.helper import parse_uuid_str_or_none
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
@@ -517,11 +517,11 @@ class DatasetRetrieval:
if attachments_with_bindings:
for _, upload_file in attachments_with_bindings:
attachment_info = File(
- id=upload_file.id,
+ file_id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
reference=build_file_reference(
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 2581c354dd..66b375dad1 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -9,7 +9,7 @@ from typing import Any, Literal
from core.model_manager import ModelInstance
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
-from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
+from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py
index 02625e242f..740d727e26 100644
--- a/api/core/repositories/human_input_repository.py
+++ b/api/core/repositories/human_input_repository.py
@@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from core.db.session_factory import session_factory
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
EmailDeliveryMethod,
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index b3424cd9a5..c87e8a3ae0 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -28,7 +28,7 @@ class ToolFileManager:
def _build_graph_file_reference(tool_file: ToolFile) -> File:
extension = guess_extension(tool_file.mimetype) or ".bin"
return File(
- type=get_file_type_by_mime_type(tool_file.mimetype),
+ file_type=get_file_type_by_mime_type(tool_file.mimetype),
transfer_method=FileTransferMethod.TOOL_FILE,
remote_url=tool_file.original_url,
reference=build_file_reference(record_id=str(tool_file.id)),
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f4588904d3..87cf6d7085 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -1082,7 +1082,12 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
- variable = variable_pool.get(tool_input.value)
+ variable_selector = tool_input.value
+ if not isinstance(variable_selector, list) or not all(
+ isinstance(selector_part, str) for selector_part in variable_selector
+ ):
+ raise ToolParameterError("Variable tool input must be a variable selector")
+ variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 9e1d41cb39..a3623d4ecd 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from models.tools import ToolModelInvoke
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 52ab605963..cd8c6352b5 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -357,7 +357,10 @@ class WorkflowTool(Tool):
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
- transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
+ transfer_method_value = file_dict.get("transfer_method")
+ if not isinstance(transfer_method_value, str):
+ raise ValueError("Workflow file mapping is missing a valid transfer_method")
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
match transfer_method:
case FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_id
diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py
similarity index 74%
rename from api/core/workflow/human_input_compat.py
rename to api/core/workflow/human_input_adapter.py
index 75a0a0c202..4b765e6aea 100644
--- a/api/core/workflow/human_input_compat.py
+++ b/api/core/workflow/human_input_adapter.py
@@ -1,8 +1,8 @@
-"""Workflow-layer adapters for legacy human-input payload keys.
+"""Workflow-to-Graphon adapters for persisted node payloads.
-Stored workflow graphs and editor payloads may still use Dify-specific human
-input recipient keys. Normalize them here before handing configs to
-`graphon` so graph-owned models only see graph-neutral field names.
+Stored workflow graphs and editor payloads still contain a small set of
+Dify-owned field spellings and value shapes. Adapt them here before handing the
+payload to Graphon so Graphon-owned models only see current contracts.
"""
from __future__ import annotations
@@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
return None
-def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
- normalized = normalize_human_input_node_data_for_graph(node_data)
+ normalized = adapt_human_input_node_data_for_graph(node_data)
raw_delivery_methods = normalized.get("delivery_methods")
if not isinstance(raw_delivery_methods, list):
return []
@@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
return False
-def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_data)
if normalized is None:
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
- if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
- return normalized
- return normalize_human_input_node_data_for_graph(normalized)
+ node_type = normalized.get("type")
+ if node_type == BuiltinNodeTypes.HUMAN_INPUT:
+ return adapt_human_input_node_data_for_graph(normalized)
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _adapt_tool_node_data_for_graph(normalized)
+ return normalized
-def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
normalized = _copy_mapping(node_config)
if normalized is None:
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
if data_mapping is None:
return normalized
- normalized["data"] = normalize_node_data_for_graph(data_mapping)
+ normalized["data"] = adapt_node_data_for_graph(data_mapping)
return normalized
+def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
+ normalized = dict(node_data)
+
+ raw_tool_configurations = normalized.get("tool_configurations")
+ if not isinstance(raw_tool_configurations, Mapping):
+ return normalized
+
+ existing_tool_parameters = normalized.get("tool_parameters")
+ normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
+ normalized_tool_configurations: dict[str, Any] = {}
+ found_legacy_tool_inputs = False
+
+ for name, value in raw_tool_configurations.items():
+ if not isinstance(value, Mapping):
+ normalized_tool_configurations[name] = value
+ continue
+
+ input_type = value.get("type")
+ input_value = value.get("value")
+ if input_type not in {"mixed", "variable", "constant"}:
+ normalized_tool_configurations[name] = value
+ continue
+
+ found_legacy_tool_inputs = True
+ normalized_tool_parameters.setdefault(name, dict(value))
+
+ flattened_value = _flatten_legacy_tool_configuration_value(
+ input_type=input_type,
+ input_value=input_value,
+ )
+ if flattened_value is not None:
+ normalized_tool_configurations[name] = flattened_value
+
+ if not found_legacy_tool_inputs:
+ return normalized
+
+ normalized["tool_parameters"] = normalized_tool_parameters
+ normalized["tool_configurations"] = normalized_tool_configurations
+ return normalized
+
+
+def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
+ if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
+ return input_value
+
+ if (
+ input_type == "variable"
+ and isinstance(input_value, list)
+ and all(isinstance(item, str) for item in input_value)
+ ):
+ return "{{#" + ".".join(input_value) + "#}}"
+
+ return None
+
+
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)
@@ -291,9 +349,9 @@ __all__ = [
"MemberRecipient",
"WebAppDeliveryMethod",
"_WebAppDeliveryConfig",
+ "adapt_human_input_node_data_for_graph",
+ "adapt_node_config_for_graph",
+ "adapt_node_data_for_graph",
"is_human_input_webapp_enabled",
- "normalize_human_input_node_data_for_graph",
- "normalize_node_config_for_graph",
- "normalize_node_data_for_graph",
"parse_human_input_delivery_methods",
]
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index 351da3444f..de4eae1b22 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
)
-from core.helper.ssrf_proxy import ssrf_proxy
+from core.helper.ssrf_proxy import graphon_ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.trigger.constants import TRIGGER_NODE_TYPES
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
@@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.file.file_manager import file_manager
from graphon.graph.graph import NodeFactory
from graphon.model_runtime.memory import PromptMessageMemory
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.base.node import Node
from graphon.nodes.code.code_node import WorkflowCodeExecutor
from graphon.nodes.code.entities import CodeLanguage
@@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ """Resolve the production node class for the requested type/version."""
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
@@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
)
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
- self._http_request_http_client = ssrf_proxy
+ self._http_request_http_client = graphon_ssrf_proxy
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
self._dify_context,
conversation_id_getter=self._conversation_id,
@@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
(including pydantic ValidationError, which subclasses ValueError),
if node type is unknown, or if no implementation exists for the resolved version
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
+ # Graph configs are initially validated against permissive shared node data.
+ # Re-validate using the resolved node class so workflow-local node schemas
+ # stay explicit and constructors receive the concrete typed payload.
+ resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
@@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=True,
include_llm_file_saver=True,
@@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
- node_data=node_data,
+ node_data=resolved_node_data,
wrap_model_instance=True,
include_http_client=False,
include_llm_file_saver=False,
@@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**node_init_kwargs,
@@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
"""
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
"""
- return node_class.validate_node_data(node_data)
+ validate_node_data = getattr(node_class, "validate_node_data", None)
+ if callable(validate_node_data):
+ return cast("BaseNodeData", validate_node_data(node_data))
+ return node_data
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py
index 2e632e56f0..b8725853c4 100644
--- a/api/core/workflow/node_runtime.py
+++ b/api/core/workflow/node_runtime.py
@@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable, Generator, Mapping, Sequence
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.llm.runtime_protocols import (
PreparedLLMProtocol,
@@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
from models.model import UploadFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from .human_input_compat import (
+from .human_input_adapter import (
BoundRecipient,
DeliveryChannelConfig,
DeliveryMethodType,
@@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
return self._model_instance.get_llm_num_tokens(prompt_messages)
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResult: ...
+
+ @overload
+ def invoke_llm(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ model_parameters: Mapping[str, Any],
+ tools: Sequence[PromptMessageTool] | None,
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunk, None, None]: ...
+
def invoke_llm(
self,
*,
@@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
stream=stream,
)
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[False],
+ ) -> LLMResultWithStructuredOutput: ...
+
+ @overload
+ def invoke_llm_with_structured_output(
+ self,
+ *,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Mapping[str, Any],
+ stop: Sequence[str] | None,
+ stream: Literal[True],
+ ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
def invoke_llm_with_structured_output(
self,
*,
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 7b000101b0..68a24e86b1 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
@@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: AgentNodeData,
+ *,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
- *,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index e4f6b3b470..f3006c4242 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.file_reference import resolve_file_record_id
from core.workflow.system_variables import SystemVariableKey, get_system_segment
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
NodeExecutionType,
@@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: DatasourceNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index d5cab05dbe..9c1b7ab2c4 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
from core.rag.summary_index.summary_index import SummaryIndex
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
@@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeIndexNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
) -> None:
- super().__init__(id, config, graph_init_params, graph_runtime_state)
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
self.index_processor = IndexProcessor()
self.summary_index_service = SummaryIndex()
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 47ad14b499..25f73e446d 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.file_reference import parse_file_reference
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict
from graphon.enums import (
BuiltinNodeTypes,
WorkflowNodeExecutionMetadataKey,
@@ -50,6 +49,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
+ if value is None or isinstance(value, (str, float)):
+ return value
+ if isinstance(value, int) and not isinstance(value, bool):
+ return value
+ return str(value)
+
+
+def _normalize_metadata_filter_sequence_item(value: object) -> str:
+ return value if isinstance(value, str) else str(value)
+
+
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
@@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
- id: str,
- config: NodeConfigDict,
+ node_id: str,
+ config: KnowledgeRetrievalNodeData,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
- ):
+ ) -> None:
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
resolved_conditions: list[Condition] = []
for cond in conditions.conditions or []:
value = cond.value
+ resolved_value: str | Sequence[str] | int | float | None
if isinstance(value, str):
segment_group = variable_pool.convert_template(value)
if len(segment_group.value) == 1:
- resolved_value = segment_group.value[0].to_object()
+ resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
else:
resolved_value = segment_group.text
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
- resolved_values = []
- for v in value: # type: ignore
+ resolved_values: list[str] = []
+ for v in value:
segment_group = variable_pool.convert_template(v)
if len(segment_group.value) == 1:
- resolved_values.append(segment_group.value[0].to_object())
+ resolved_values.append(
+ _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
+ )
else:
resolved_values.append(segment_group.text)
resolved_value = resolved_values
diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py
index 288d37d265..1d2ad4d445 100644
--- a/api/factories/file_factory/builders.py
+++ b/api/factories/file_factory/builders.py
@@ -10,8 +10,8 @@ from typing import Any
from sqlalchemy import select
from core.app.file_access import FileAccessControllerProtocol
+from core.db.session_factory import session_factory
from core.workflow.file_reference import build_file_reference
-from extensions.ext_database import db
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
from models import ToolFile, UploadFile
@@ -135,29 +135,30 @@ def _build_from_local_file(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
- row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
- if row is None:
- raise ValueError("Invalid upload file")
+ with session_factory.create_session() as session:
+ row = session.scalar(access_controller.apply_upload_file_filters(stmt))
+ if row is None:
+ raise ValueError("Invalid upload file")
- detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type", "custom"),
- strict_type_validation=strict_type_validation,
- )
+ detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type", "custom"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("id"),
- filename=row.name,
- extension="." + row.extension,
- mime_type=row.mime_type,
- type=file_type,
- transfer_method=transfer_method,
- remote_url=row.source_url,
- reference=build_file_reference(record_id=str(row.id)),
- size=row.size,
- storage_key=row.key,
- )
+ return File(
+ file_id=mapping.get("id"),
+ filename=row.name,
+ extension="." + row.extension,
+ mime_type=row.mime_type,
+ file_type=file_type,
+ transfer_method=transfer_method,
+ remote_url=row.source_url,
+ reference=build_file_reference(record_id=str(row.id)),
+ size=row.size,
+ storage_key=row.key,
+ )
def _build_from_remote_url(
@@ -179,32 +180,33 @@ def _build_from_remote_url(
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)
- upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
- if upload_file is None:
- raise ValueError("Invalid upload file")
+ with session_factory.create_session() as session:
+ upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
+ if upload_file is None:
+ raise ValueError("Invalid upload file")
- detected_file_type = standardize_file_type(
- extension="." + upload_file.extension,
- mime_type=upload_file.mime_type,
- )
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type"),
- strict_type_validation=strict_type_validation,
- )
+ detected_file_type = standardize_file_type(
+ extension="." + upload_file.extension,
+ mime_type=upload_file.mime_type,
+ )
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("id"),
- filename=upload_file.name,
- extension="." + upload_file.extension,
- mime_type=upload_file.mime_type,
- type=file_type,
- transfer_method=transfer_method,
- remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
- reference=build_file_reference(record_id=str(upload_file.id)),
- size=upload_file.size,
- storage_key=upload_file.key,
- )
+ return File(
+ file_id=mapping.get("id"),
+ filename=upload_file.name,
+ extension="." + upload_file.extension,
+ mime_type=upload_file.mime_type,
+ file_type=file_type,
+ transfer_method=transfer_method,
+ remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
+ reference=build_file_reference(record_id=str(upload_file.id)),
+ size=upload_file.size,
+ storage_key=upload_file.key,
+ )
url = mapping.get("url") or mapping.get("remote_url")
if not url:
@@ -220,9 +222,9 @@ def _build_from_remote_url(
)
return File(
- id=mapping.get("id"),
+ file_id=mapping.get("id"),
filename=filename,
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
@@ -247,30 +249,31 @@ def _build_from_tool_file(
ToolFile.id == tool_file_id,
ToolFile.tenant_id == tenant_id,
)
- tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
- if tool_file is None:
- raise ValueError(f"ToolFile {tool_file_id} not found")
+ with session_factory.create_session() as session:
+ tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt))
+ if tool_file is None:
+ raise ValueError(f"ToolFile {tool_file_id} not found")
- extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
- detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type"),
- strict_type_validation=strict_type_validation,
- )
+ extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
+ detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("id"),
- filename=tool_file.name,
- type=file_type,
- transfer_method=transfer_method,
- remote_url=tool_file.original_url,
- reference=build_file_reference(record_id=str(tool_file.id)),
- extension=extension,
- mime_type=tool_file.mimetype,
- size=tool_file.size,
- storage_key=tool_file.file_key,
- )
+ return File(
+ file_id=mapping.get("id"),
+ filename=tool_file.name,
+ file_type=file_type,
+ transfer_method=transfer_method,
+ remote_url=tool_file.original_url,
+ reference=build_file_reference(record_id=str(tool_file.id)),
+ extension=extension,
+ mime_type=tool_file.mimetype,
+ size=tool_file.size,
+ storage_key=tool_file.file_key,
+ )
def _build_from_datasource_file(
@@ -289,31 +292,32 @@ def _build_from_datasource_file(
UploadFile.id == datasource_file_id,
UploadFile.tenant_id == tenant_id,
)
- datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
- if datasource_file is None:
- raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
+ with session_factory.create_session() as session:
+ datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
+ if datasource_file is None:
+ raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
- extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
- detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
- file_type = _resolve_file_type(
- detected_file_type=detected_file_type,
- specified_type=mapping.get("type"),
- strict_type_validation=strict_type_validation,
- )
+ extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
+ detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
+ file_type = _resolve_file_type(
+ detected_file_type=detected_file_type,
+ specified_type=mapping.get("type"),
+ strict_type_validation=strict_type_validation,
+ )
- return File(
- id=mapping.get("datasource_file_id"),
- filename=datasource_file.name,
- type=file_type,
- transfer_method=FileTransferMethod.TOOL_FILE,
- remote_url=datasource_file.source_url,
- reference=build_file_reference(record_id=str(datasource_file.id)),
- extension=extension,
- mime_type=datasource_file.mime_type,
- size=datasource_file.size,
- storage_key=datasource_file.key,
- url=datasource_file.source_url,
- )
+ return File(
+ file_id=mapping.get("datasource_file_id"),
+ filename=datasource_file.name,
+ file_type=file_type,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ remote_url=datasource_file.source_url,
+ reference=build_file_reference(record_id=str(datasource_file.id)),
+ extension=extension,
+ mime_type=datasource_file.mime_type,
+ size=datasource_file.size,
+ storage_key=datasource_file.key,
+ url=datasource_file.source_url,
+ )
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:
diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py
index b5acbbbcb4..d518114777 100644
--- a/api/fields/_value_type_serializer.py
+++ b/api/fields/_value_type_serializer.py
@@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
- return v.value_type.exposed_type().value
+ return str(v.value_type.exposed_type())
else:
value_type = v.get("value_type")
if value_type is None:
raise ValueError("value_type is required but not provided")
- return value_type.exposed_type().value
+ return str(value_type.exposed_type())
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index cf4a71d545..e4219ba1ee 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
def _normalize_value_type(cls, value: Any) -> str:
exposed_type = getattr(value, "exposed_type", None)
if callable(exposed_type):
- return str(exposed_type().value)
+ return str(exposed_type())
if isinstance(value, str):
try:
- return str(SegmentType(value).exposed_type().value)
+ return str(SegmentType(value).exposed_type())
except ValueError:
return value
try:
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index f9b5e98936..6e947858ba 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
- "value_type": value.value_type.exposed_type().value,
+ "value_type": str(value.value_type.exposed_type()),
"description": value.description,
}
if isinstance(value, dict):
diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py
index 9b53918f24..934aacb45b 100644
--- a/api/libs/oauth_data_source.py
+++ b/api/libs/oauth_data_source.py
@@ -6,8 +6,8 @@ from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
+from core.db.session_factory import session_factory
from core.helper.http_client_pooling import get_pooled_http_client
-from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
- data_source_binding = db.session.scalar(
- select(DataSourceOauthBinding).where(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.access_token == access_token,
+ with session_factory.create_session() as session:
+ data_source_binding = session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.access_token == access_token,
+ )
)
- )
- if data_source_binding:
- data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
- data_source_binding.disabled = False
- data_source_binding.updated_at = naive_utc_now()
- db.session.commit()
- else:
- new_data_source_binding = DataSourceOauthBinding(
- tenant_id=current_user.current_tenant_id,
- access_token=access_token,
- source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
- provider="notion",
- )
- db.session.add(new_data_source_binding)
- db.session.commit()
+ if data_source_binding:
+ data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
+ data_source_binding.disabled = False
+ data_source_binding.updated_at = naive_utc_now()
+ session.commit()
+ else:
+ new_data_source_binding = DataSourceOauthBinding(
+ tenant_id=current_user.current_tenant_id,
+ access_token=access_token,
+ source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
+ provider="notion",
+ )
+ session.add(new_data_source_binding)
+ session.commit()
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
@@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
pages=pages,
)
# save data source binding
- data_source_binding = db.session.scalar(
- select(DataSourceOauthBinding).where(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.access_token == access_token,
+ with session_factory.create_session() as session:
+ data_source_binding = session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.access_token == access_token,
+ )
)
- )
- if data_source_binding:
- data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
- data_source_binding.disabled = False
- data_source_binding.updated_at = naive_utc_now()
- db.session.commit()
- else:
- new_data_source_binding = DataSourceOauthBinding(
- tenant_id=current_user.current_tenant_id,
- access_token=access_token,
- source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
- provider="notion",
- )
- db.session.add(new_data_source_binding)
- db.session.commit()
+ if data_source_binding:
+ data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
+ data_source_binding.disabled = False
+ data_source_binding.updated_at = naive_utc_now()
+ session.commit()
+ else:
+ new_data_source_binding = DataSourceOauthBinding(
+ tenant_id=current_user.current_tenant_id,
+ access_token=access_token,
+ source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
+ provider="notion",
+ )
+ session.add(new_data_source_binding)
+ session.commit()
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
- data_source_binding = db.session.scalar(
- select(DataSourceOauthBinding).where(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.id == binding_id,
- DataSourceOauthBinding.disabled == False,
+ with session_factory.create_session() as session:
+ data_source_binding = session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.id == binding_id,
+ DataSourceOauthBinding.disabled == False,
+ )
)
- )
- if data_source_binding:
- # get all authorized pages
- pages = self.get_authorized_pages(data_source_binding.access_token)
- source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
- new_source_info = self._build_source_info(
- workspace_name=source_info["workspace_name"],
- workspace_icon=source_info["workspace_icon"],
- workspace_id=source_info["workspace_id"],
- pages=pages,
- )
- data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
- data_source_binding.disabled = False
- data_source_binding.updated_at = naive_utc_now()
- db.session.commit()
- else:
- raise ValueError("Data source binding not found")
+ if data_source_binding:
+ # get all authorized pages
+ pages = self.get_authorized_pages(data_source_binding.access_token)
+ source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
+ new_source_info = self._build_source_info(
+ workspace_name=source_info["workspace_name"],
+ workspace_icon=source_info["workspace_icon"],
+ workspace_id=source_info["workspace_id"],
+ pages=pages,
+ )
+ data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
+ data_source_binding.disabled = False
+ data_source_binding.updated_at = naive_utc_now()
+ session.commit()
+ else:
+ raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 50301dd2d7..eee5c39a0e 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -1715,7 +1715,7 @@ class SegmentAttachmentBinding(TypeBase):
)
-class DocumentSegmentSummary(Base):
+class DocumentSegmentSummary(TypeBase):
__tablename__ = "document_segment_summaries"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"),
@@ -1725,25 +1725,40 @@ class DocumentSegmentSummary(Base):
sa.Index("document_segment_summaries_status_idx", "status"),
)
- id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
+ id: Mapped[str] = mapped_column(
+ StringUUID,
+ nullable=False,
+ insert_default=lambda: str(uuid4()),
+ default_factory=lambda: str(uuid4()),
+ init=False,
+ )
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
- summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
- summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
- summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
- tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
- status: Mapped[str] = mapped_column(
- EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'")
+ summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
+ tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None)
+ status: Mapped[SummaryStatus] = mapped_column(
+ EnumText(SummaryStatus, length=32),
+ nullable=False,
+ server_default=sa.text("'generating'"),
+ default=SummaryStatus.GENERATING,
+ )
+ error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
+ enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
+ disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
+ disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
+ created_at: Mapped[datetime] = mapped_column(
+ DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
- error: Mapped[str] = mapped_column(LongText, nullable=True)
- enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
- disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
- disabled_by = mapped_column(StringUUID, nullable=True)
- created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
- DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+ DateTime,
+ nullable=False,
+ server_default=func.current_timestamp(),
+ onupdate=func.current_timestamp(),
+ init=False,
)
def __repr__(self):
diff --git a/api/models/human_input.py b/api/models/human_input.py
index b4c7a634b6..7447d3efcb 100644
--- a/api/models/human_input.py
+++ b/api/models/human_input.py
@@ -6,7 +6,7 @@ import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import Mapped, mapped_column, relationship
-from core.workflow.human_input_compat import DeliveryMethodType
+from core.workflow.human_input_adapter import DeliveryMethodType
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.helper import generate_string
diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py
index a2dc8f6157..77dcbd13d4 100644
--- a/api/models/utils/file_input_compat.py
+++ b/api/models/utils/file_input_compat.py
@@ -5,7 +5,8 @@ from functools import lru_cache
from typing import Any
from core.workflow.file_reference import parse_file_reference
-from graphon.file import File, FileTransferMethod
+from graphon.file import File, FileTransferMethod, FileType
+from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
@lru_cache(maxsize=1)
@@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
return tenant_resolver()
+def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
+ """Build a graph `File` directly from serialized metadata."""
+
+ def _coerce_file_type(value: Any) -> FileType:
+ if isinstance(value, FileType):
+ return value
+ if isinstance(value, str):
+ return FileType.value_of(value)
+ raise ValueError("file type is required in file mapping")
+
+ mapping = dict(file_mapping)
+ transfer_method_value = mapping.get("transfer_method")
+ if isinstance(transfer_method_value, FileTransferMethod):
+ transfer_method = transfer_method_value
+ elif isinstance(transfer_method_value, str):
+ transfer_method = FileTransferMethod.value_of(transfer_method_value)
+ else:
+ raise ValueError("transfer_method is required in file mapping")
+
+ file_id = mapping.get("file_id")
+ if not isinstance(file_id, str) or not file_id:
+ legacy_id = mapping.get("id")
+ file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
+
+ related_id = resolve_file_record_id(mapping)
+ if related_id is None:
+ raw_related_id = mapping.get("related_id")
+ related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
+
+ remote_url = mapping.get("remote_url")
+ if not isinstance(remote_url, str) or not remote_url:
+ url = mapping.get("url")
+ remote_url = url if isinstance(url, str) and url else None
+
+ reference = mapping.get("reference")
+ if not isinstance(reference, str) or not reference:
+ reference = None
+
+ filename = mapping.get("filename")
+ if not isinstance(filename, str):
+ filename = None
+
+ extension = mapping.get("extension")
+ if not isinstance(extension, str):
+ extension = None
+
+ mime_type = mapping.get("mime_type")
+ if not isinstance(mime_type, str):
+ mime_type = None
+
+ size = mapping.get("size", -1)
+ if not isinstance(size, int):
+ size = -1
+
+ storage_key = mapping.get("storage_key")
+ if not isinstance(storage_key, str):
+ storage_key = None
+
+ tenant_id = mapping.get("tenant_id")
+ if not isinstance(tenant_id, str):
+ tenant_id = None
+
+ dify_model_identity = mapping.get("dify_model_identity")
+ if not isinstance(dify_model_identity, str):
+ dify_model_identity = FILE_MODEL_IDENTITY
+
+ tool_file_id = mapping.get("tool_file_id")
+ if not isinstance(tool_file_id, str):
+ tool_file_id = None
+
+ upload_file_id = mapping.get("upload_file_id")
+ if not isinstance(upload_file_id, str):
+ upload_file_id = None
+
+ datasource_file_id = mapping.get("datasource_file_id")
+ if not isinstance(datasource_file_id, str):
+ datasource_file_id = None
+
+ return File(
+ file_id=file_id,
+ tenant_id=tenant_id,
+ file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
+ transfer_method=transfer_method,
+ remote_url=remote_url,
+ reference=reference,
+ related_id=related_id,
+ filename=filename,
+ extension=extension,
+ mime_type=mime_type,
+ size=size,
+ storage_key=storage_key,
+ dify_model_identity=dify_model_identity,
+ url=remote_url,
+ tool_file_id=tool_file_id,
+ upload_file_id=upload_file_id,
+ datasource_file_id=datasource_file_id,
+ )
+
+
+def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
+ """Recursively rebuild serialized graph file payloads into `File` objects.
+
+ `graphon` 0.2.2 no longer accepts legacy serialized file mappings via
+ `model_validate_json()`. Dify keeps this recovery path at the model boundary
+ so historical JSON blobs remain readable without reintroducing global graph
+ patches or test-local coercion.
+ """
+ if isinstance(value, list):
+ return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
+
+ if isinstance(value, dict):
+ if maybe_file_object(value):
+ return build_file_from_mapping_without_lookup(file_mapping=value)
+ return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
+
+ return value
+
+
def build_file_from_stored_mapping(
*,
file_mapping: Mapping[str, Any],
@@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
- remote_url = mapping.get("remote_url")
- if not isinstance(remote_url, str) or not remote_url:
- url = mapping.get("url")
- if isinstance(url, str) and url:
- mapping["remote_url"] = url
- return File.model_validate(mapping)
+ return build_file_from_mapping_without_lookup(file_mapping=mapping)
return file_factory.build_from_mapping(
mapping=mapping,
diff --git a/api/models/workflow.py b/api/models/workflow.py
index dfda03c2ee..d127244b0f 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
-from core.workflow.human_input_compat import normalize_node_config_for_graph
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.variable_prefixes import (
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID
-from .utils.file_input_compat import build_file_from_stored_mapping
+from .utils.file_input_compat import (
+ build_file_from_mapping_without_lookup,
+ build_file_from_stored_mapping,
+)
logger = logging.getLogger(__name__)
@@ -290,7 +293,7 @@ class Workflow(Base): # bug
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration:
raise NodeNotFoundError(node_id)
- return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
+ return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
@staticmethod
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
@@ -1688,7 +1691,7 @@ class WorkflowDraftVariable(Base):
return cast(Any, value)
normalized_file = dict(value)
normalized_file.pop("tenant_id", None)
- return File.model_validate(normalized_file)
+ return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
elif isinstance(value, list) and value:
value_list = cast(list[Any], value)
first: Any = value_list[0]
@@ -1698,7 +1701,7 @@ class WorkflowDraftVariable(Base):
for item in value_list:
normalized_file = dict(cast(dict[str, Any], item))
normalized_file.pop("tenant_id", None)
- file_list.append(File.model_validate(normalized_file))
+ file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
return cast(Any, file_list)
else:
return cast(Any, value)
diff --git a/api/providers/README.md b/api/providers/README.md
index a00ec8bc52..5d5e6db9af 100644
--- a/api/providers/README.md
+++ b/api/providers/README.md
@@ -10,3 +10,6 @@ This directory holds **optional workspace packages** that plug into Dify’s API
Provider tests often live next to the package, e.g. `providers///tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
+## Excluding Providers
+
+In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-` and `--group trace-` to enable specific providers.
diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md
new file mode 100644
index 0000000000..a7ffa5ed26
--- /dev/null
+++ b/api/providers/trace/README.md
@@ -0,0 +1,78 @@
+# Trace providers
+
+This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others).
+
+Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping.
+
+## Architecture
+
+| Layer | Location | Role |
+|--------|----------|------|
+| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
+| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
+| Your package | `api/providers/trace/trace-/` | Pydantic config + subclass of `BaseTraceInstance` |
+
+At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.
+
+## What you implement
+
+### 1. Config model (`BaseTracingConfig`)
+
+Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.
+
+Fields fall into two groups used by the manager:
+
+- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
+- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).
+
+List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
+
+### 2. Trace instance (`BaseTraceInstance`)
+
+Subclass `BaseTraceInstance` and implement:
+
+```python
+def trace(self, trace_info: BaseTraceInfo) -> None:
+ ...
+```
+
+Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:
+
+- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
+- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
+- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`
+
+You may ignore categories your backend does not support; existing providers often no-op unhandled types.
+
+Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
+
+### 3. Register in the API core
+
+Upstream changes are required so Dify knows your provider exists:
+
+1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
+2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
+ - `config_class`: your Pydantic config type
+ - `secret_keys` / `other_keys`: lists of field names as above
+ - `trace_instance`: your `BaseTraceInstance` subclass
+ Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`.
+
+If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
+
+## Package layout
+
+Each provider is a normal uv workspace member, for example:
+
+- `api/providers/trace/trace-/pyproject.toml` — project name `dify-trace-`, dependencies on vendor SDKs
+- `api/providers/trace/trace-/src/dify_trace_/` — `config.py`, `_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
+
+Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
+
+## Wiring into the `api` workspace
+
+In `api/pyproject.toml`:
+
+1. **`[tool.uv.sources]`** — `dify-trace- = { workspace = true }`
+2. **`[dependency-groups]`** — add `trace- = ["dify-trace-"]` and include `dify-trace-` in `trace-all` if it should ship with the default bundle
+
+After changing metadata, run **`uv sync`** from `api/`.
diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml
new file mode 100644
index 0000000000..bcef7e9fb1
--- /dev/null
+++ b/api/providers/trace/trace-aliyun/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-trace-aliyun"
+version = "0.0.1"
+dependencies = [
+ # versions inherited from parent
+ "opentelemetry-api",
+ "opentelemetry-exporter-otlp-proto-grpc",
+ "opentelemetry-sdk",
+ "opentelemetry-semantic-conventions",
+]
+description = "Dify ops tracing provider (Aliyun)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py
similarity index 100%
rename from api/core/ops/aliyun_trace/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py
similarity index 98%
rename from api/core/ops/aliyun_trace/aliyun_trace.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py
index 76e81242f4..54d2f8167f 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py
@@ -4,7 +4,20 @@ from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
-from core.ops.aliyun_trace.data_exporter.traceclient import (
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.trace_entity import (
+ BaseTraceInfo,
+ DatasetRetrievalTraceInfo,
+ GenerateNameTraceInfo,
+ MessageTraceInfo,
+ ModerationTraceInfo,
+ SuggestedQuestionTraceInfo,
+ ToolTraceInfo,
+ WorkflowTraceInfo,
+)
+from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_aliyun.config import AliyunConfig
+from dify_trace_aliyun.data_exporter.traceclient import (
TraceClient,
build_endpoint,
convert_datetime_to_nanoseconds,
@@ -12,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
convert_to_trace_id,
generate_span_id,
)
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
+from dify_trace_aliyun.entities.semconv import (
DIFY_APP_ID,
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
@@ -32,7 +45,7 @@ from core.ops.aliyun_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.aliyun_trace.utils import (
+from dify_trace_aliyun.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@@ -44,19 +57,6 @@ from core.ops.aliyun_trace.utils import (
get_workflow_node_status,
serialize_json_data,
)
-from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import AliyunConfig
-from core.ops.entities.trace_entity import (
- BaseTraceInfo,
- DatasetRetrievalTraceInfo,
- GenerateNameTraceInfo,
- MessageTraceInfo,
- ModerationTraceInfo,
- SuggestedQuestionTraceInfo,
- ToolTraceInfo,
- WorkflowTraceInfo,
-)
-from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py
new file mode 100644
index 0000000000..e0133e6cc9
--- /dev/null
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py
@@ -0,0 +1,32 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class AliyunConfig(BaseTracingConfig):
+ """
+ Model class for Aliyun tracing config.
+ """
+
+ app_name: str = "dify_app"
+ license_key: str
+ endpoint: str
+
+ @field_validator("app_name")
+ @classmethod
+ def app_name_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "dify_app")
+
+ @field_validator("license_key")
+ @classmethod
+ def license_key_validator(cls, v, info: ValidationInfo):
+ if not v or v.strip() == "":
+ raise ValueError("License key cannot be empty")
+ return v
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # aliyun uses two URL formats, which may include a URL path
+ return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py
similarity index 100%
rename from api/core/ops/aliyun_trace/data_exporter/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
similarity index 98%
rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
index 67d5163b0f..00aab6bf89 100644
--- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
@@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
-from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
+from dify_trace_aliyun.entities.semconv import ACS_ARMS_SERVICE_FEATURE
INVALID_SPAN_ID: Final[int] = 0x0000000000000000
INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000
diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py
similarity index 100%
rename from api/core/ops/aliyun_trace/entities/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py
diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py
similarity index 100%
rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py
diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py
similarity index 100%
rename from api/core/ops/aliyun_trace/entities/semconv.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py
diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed
similarity index 100%
rename from api/core/ops/arize_phoenix_trace/__init__.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed
diff --git a/api/core/ops/aliyun_trace/utils.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py
similarity index 97%
rename from api/core/ops/aliyun_trace/utils.py
rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py
index 2e02a186cc..5678c66adb 100644
--- a/api/core/ops/aliyun_trace/utils.py
+++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py
@@ -4,7 +4,8 @@ from typing import Any, TypedDict
from opentelemetry.trace import Link, Status, StatusCode
-from core.ops.aliyun_trace.entities.semconv import (
+from core.rag.models.document import Document
+from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@@ -13,7 +14,6 @@ from core.ops.aliyun_trace.entities.semconv import (
OUTPUT_VALUE,
GenAISpanKind,
)
-from core.rag.models.document import Document
from extensions.ext_database import db
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
- from core.ops.aliyun_trace.data_exporter.traceclient import create_link
+ from dify_trace_aliyun.data_exporter.traceclient import create_link
links = []
if trace_id:
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py
similarity index 86%
rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py
index acb43d4036..286dda419c 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py
@@ -5,10 +5,7 @@ from unittest.mock import MagicMock, patch
import httpx
import pytest
-from opentelemetry.sdk.trace import ReadableSpan
-from opentelemetry.trace import SpanKind, Status, StatusCode
-
-from core.ops.aliyun_trace.data_exporter.traceclient import (
+from dify_trace_aliyun.data_exporter.traceclient import (
INVALID_SPAN_ID,
SpanBuilder,
TraceClient,
@@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import (
create_link,
generate_span_id,
)
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.trace import SpanKind, Status, StatusCode
@pytest.fixture
@@ -41,8 +40,8 @@ def trace_client_factory():
class TestTraceClient:
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname")
def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory):
mock_gethostname.return_value = "test-host"
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -56,7 +55,7 @@ class TestTraceClient:
client.shutdown()
assert client.done is True
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -64,8 +63,8 @@ class TestTraceClient:
client.export(spans)
mock_exporter.export.assert_called_once_with(spans)
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory):
mock_response = MagicMock()
mock_response.status_code = 405
@@ -74,8 +73,8 @@ class TestTraceClient:
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.api_check() is True
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory):
mock_response = MagicMock()
mock_response.status_code = 500
@@ -84,8 +83,8 @@ class TestTraceClient:
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.api_check() is False
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory):
mock_head.side_effect = httpx.RequestError("Connection error")
@@ -93,12 +92,12 @@ class TestTraceClient:
with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"):
client.api_check()
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_get_project_url(self, mock_exporter_class, trace_client_factory):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm"
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_add_span(self, mock_exporter_class, trace_client_factory):
client = trace_client_factory(
service_name="test-service",
@@ -134,8 +133,8 @@ class TestTraceClient:
assert len(client.queue) == 2
mock_notify.assert_called_once()
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.logger")
def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory):
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1)
@@ -159,7 +158,7 @@ class TestTraceClient:
assert len(client.queue) == 1
mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.")
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_export_batch_error(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
mock_exporter.export.side_effect = Exception("Export failed")
@@ -168,11 +167,11 @@ class TestTraceClient:
mock_span = MagicMock(spec=ReadableSpan)
client.queue.append(mock_span)
- with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger:
+ with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger:
client._export_batch()
mock_logger.warning.assert_called()
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_worker_loop(self, mock_exporter_class, trace_client_factory):
# We need to test the wait timeout in _worker
# But _worker runs in a thread. Let's mock condition.wait.
@@ -189,7 +188,7 @@ class TestTraceClient:
# mock_wait might have been called
assert mock_wait.called or client.done
- @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter")
+ @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter")
def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory):
mock_exporter = mock_exporter_class.return_value
client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint")
@@ -268,7 +267,7 @@ def test_generate_span_id():
assert span_id != INVALID_SPAN_ID
# Test retry loop
- with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand:
+ with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand:
mock_rand.side_effect = [INVALID_SPAN_ID, 999]
span_id = generate_span_id()
assert span_id == 999
@@ -290,7 +289,7 @@ def test_convert_to_trace_id():
def test_convert_string_to_id():
assert convert_string_to_id("test") > 0
# Test with None string
- with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen:
+ with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen:
mock_gen.return_value = 12345
assert convert_string_to_id(None) == 12345
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py
similarity index 97%
rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py
index 2fcb927e0c..38d33dd21b 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py
@@ -1,11 +1,10 @@
import pytest
+from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import ValidationError
-from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata
-
class TestTraceMetadata:
def test_trace_metadata_init(self):
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py
similarity index 97%
rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py
index 3961555b9a..9cab40748f 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py
@@ -1,4 +1,4 @@
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.entities.semconv import (
ACS_ARMS_SERVICE_FEATURE,
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py
similarity index 99%
rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py
index c2324fdec4..c1b11c9186 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py
@@ -4,12 +4,11 @@ from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
+import dify_trace_aliyun.aliyun_trace as aliyun_trace_module
import pytest
-from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
-
-import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module
-from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.aliyun_trace import AliyunDataTrace
+from dify_trace_aliyun.config import AliyunConfig
+from dify_trace_aliyun.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_INPUT_MESSAGE,
GEN_AI_OUTPUT_MESSAGE,
@@ -24,7 +23,8 @@ from core.ops.aliyun_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.entities.config_entity import AliyunConfig
+from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags
+
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py
similarity index 95%
rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py
rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py
index e4d8f2d5ea..a9e7b80c2a 100644
--- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py
@@ -1,9 +1,7 @@
import json
from unittest.mock import MagicMock
-from opentelemetry.trace import Link, StatusCode
-
-from core.ops.aliyun_trace.entities.semconv import (
+from dify_trace_aliyun.entities.semconv import (
GEN_AI_FRAMEWORK,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
@@ -11,7 +9,7 @@ from core.ops.aliyun_trace.entities.semconv import (
INPUT_VALUE,
OUTPUT_VALUE,
)
-from core.ops.aliyun_trace.utils import (
+from dify_trace_aliyun.utils import (
create_common_span_attributes,
create_links_from_trace_id,
create_status_from_error,
@@ -23,6 +21,8 @@ from core.ops.aliyun_trace.utils import (
get_workflow_node_status,
serialize_json_data,
)
+from opentelemetry.trace import Link, StatusCode
+
from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@@ -48,7 +48,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = end_user_data
- from core.ops.aliyun_trace.utils import db
+ from dify_trace_aliyun.utils import db
monkeypatch.setattr(db, "session", mock_session)
@@ -63,7 +63,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch):
mock_session = MagicMock()
mock_session.get.return_value = None
- from core.ops.aliyun_trace.utils import db
+ from dify_trace_aliyun.utils import db
monkeypatch.setattr(db, "session", mock_session)
@@ -112,9 +112,9 @@ def test_get_workflow_node_status():
def test_create_links_from_trace_id(monkeypatch):
# Mock create_link
mock_link = MagicMock(spec=Link)
- import core.ops.aliyun_trace.data_exporter.traceclient
+ import dify_trace_aliyun.data_exporter.traceclient
- monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
+ monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link)
# Trace ID None
assert create_links_from_trace_id(None) == []
diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..1b24ee7421
--- /dev/null
+++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,85 @@
+import pytest
+from dify_trace_aliyun.config import AliyunConfig
+from pydantic import ValidationError
+
+
+class TestAliyunConfig:
+ """Test cases for AliyunConfig"""
+
+ def test_valid_config(self):
+ """Test valid Aliyun configuration"""
+ config = AliyunConfig(
+ app_name="test_app",
+ license_key="test_license_key",
+ endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
+ )
+ assert config.app_name == "test_app"
+ assert config.license_key == "test_license_key"
+ assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+ assert config.app_name == "dify_app"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ AliyunConfig()
+
+ with pytest.raises(ValidationError):
+ AliyunConfig(license_key="test_license")
+
+ with pytest.raises(ValidationError):
+ AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_app_name_validation_empty(self):
+ """Test app_name validation with empty value"""
+ config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
+ )
+ assert config.app_name == "dify_app"
+
+ def test_endpoint_validation_empty(self):
+ """Test endpoint validation with empty value"""
+ config = AliyunConfig(license_key="test_license", endpoint="")
+ assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation preserves path for Aliyun endpoints"""
+ config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
+ )
+ assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
+
+ def test_endpoint_validation_invalid_scheme(self):
+ """Test endpoint validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+ AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_endpoint_validation_no_scheme(self):
+ """Test endpoint validation rejects URLs without scheme"""
+ with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+ AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_license_key_required(self):
+ """Test that license_key is required and cannot be empty"""
+ with pytest.raises(ValidationError):
+ AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_valid_endpoint_format_examples(self):
+ """Test valid endpoint format examples from comments"""
+ valid_endpoints = [
+ # cms2.0 public endpoint
+ "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
+ # cms2.0 intranet endpoint
+ "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
+ # xtrace public endpoint
+ "http://tracing-cn-heyuan.arms.aliyuncs.com",
+ # xtrace intranet endpoint
+ "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
+ ]
+
+ for endpoint in valid_endpoints:
+ config = AliyunConfig(license_key="test_license", endpoint=endpoint)
+ assert config.endpoint == endpoint
diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml
new file mode 100644
index 0000000000..9e756944c9
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-arize-phoenix"
+version = "0.0.1"
+dependencies = [
+ "arize-phoenix-otel~=0.15.0",
+]
+description = "Dify ops tracing provider (Arize / Phoenix)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py
similarity index 100%
rename from api/core/ops/langfuse_trace/__init__.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
similarity index 99%
rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
index 78516e1a22..96df49ed0e 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
@@ -25,7 +25,6 @@ from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -39,6 +38,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile
diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py
new file mode 100644
index 0000000000..6eac5b30d2
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py
@@ -0,0 +1,45 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class ArizeConfig(BaseTracingConfig):
+ """
+ Model class for Arize tracing config.
+ """
+
+ api_key: str | None = None
+ space_id: str | None = None
+ project: str | None = None
+ endpoint: str = "https://otlp.arize.com"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "default")
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://otlp.arize.com")
+
+
+class PhoenixConfig(BaseTracingConfig):
+ """
+ Model class for Phoenix tracing config.
+ """
+
+ api_key: str | None = None
+ project: str | None = None
+ endpoint: str = "https://app.phoenix.arize.com"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "default")
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return validate_url_with_path(v, "https://app.phoenix.arize.com")
diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed
similarity index 100%
rename from api/core/ops/langfuse_trace/entities/__init__.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed
diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py
similarity index 91%
rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py
index 4ce9e22fd7..b0691a87ea 100644
--- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py
@@ -2,11 +2,7 @@ from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
-from opentelemetry.sdk.trace import Tracer
-from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
-from opentelemetry.trace import StatusCode
-
-from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
+from dify_trace_arize_phoenix.arize_phoenix_trace import (
ArizePhoenixDataTrace,
datetime_to_nanos,
error_to_string,
@@ -15,7 +11,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import (
setup_tracer,
wrap_span_metadata,
)
-from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
+from opentelemetry.sdk.trace import Tracer
+from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
+from opentelemetry.trace import StatusCode
+
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -80,7 +80,7 @@ def test_datetime_to_nanos():
expected = int(dt.timestamp() * 1_000_000_000)
assert datetime_to_nanos(dt) == expected
- with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt:
+ with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt:
mock_now = MagicMock()
mock_now.timestamp.return_value = 1704110400.0
mock_dt.now.return_value = mock_now
@@ -142,8 +142,8 @@ def test_wrap_span_metadata():
assert res == {"a": 1, "b": 2, "created_from": "Dify"}
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_arize(mock_provider, mock_exporter):
config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p")
setup_tracer(config)
@@ -151,8 +151,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter):
assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1"
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider")
def test_setup_tracer_phoenix(mock_provider, mock_exporter):
config = PhoenixConfig(endpoint="http://p.com", project="p")
setup_tracer(config)
@@ -162,7 +162,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter):
def test_setup_tracer_exception():
config = ArizeConfig(endpoint="http://a.com", project="p")
- with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
+ with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")):
with pytest.raises(Exception, match="boom"):
setup_tracer(config)
@@ -172,7 +172,7 @@ def test_setup_tracer_exception():
@pytest.fixture
def trace_instance():
- with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup:
+ with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup:
mock_tracer = MagicMock(spec=Tracer)
mock_processor = MagicMock()
mock_setup.return_value = (mock_tracer, mock_processor)
@@ -228,9 +228,9 @@ def test_trace_exception(trace_instance):
trace_instance.trace(_make_workflow_info())
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory")
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@@ -262,7 +262,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac
assert trace_instance.tracer.start_span.call_count >= 2
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_workflow_trace_no_app_id(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_workflow_info()
@@ -271,7 +271,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance):
trace_instance.workflow_trace(info)
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_message_trace_success(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()
@@ -291,7 +291,7 @@ def test_message_trace_success(mock_db, trace_instance):
assert trace_instance.tracer.start_span.call_count >= 1
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
def test_message_trace_with_error(mock_db, trace_instance):
mock_db.engine = MagicMock()
info = _make_message_info()
diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
similarity index 94%
rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
index 4b925390d9..a01c63ae61 100644
--- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
@@ -1,6 +1,6 @@
+from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from openinference.semconv.trace import OpenInferenceSpanKindValues
-from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..11e951c3b1
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,88 @@
+import pytest
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
+from pydantic import ValidationError
+
+
+class TestArizeConfig:
+ """Test cases for ArizeConfig"""
+
+ def test_valid_config(self):
+ """Test valid Arize configuration"""
+ config = ArizeConfig(
+ api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
+ )
+ assert config.api_key == "test_key"
+ assert config.space_id == "test_space"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.arize.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = ArizeConfig()
+ assert config.api_key is None
+ assert config.space_id is None
+ assert config.project is None
+ assert config.endpoint == "https://otlp.arize.com"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = ArizeConfig(project="")
+ assert config.project == "default"
+
+ def test_project_validation_none(self):
+ """Test project validation with None value"""
+ config = ArizeConfig(project=None)
+ assert config.project == "default"
+
+ def test_endpoint_validation_empty(self):
+ """Test endpoint validation with empty value"""
+ config = ArizeConfig(endpoint="")
+ assert config.endpoint == "https://otlp.arize.com"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation normalizes URL by removing path"""
+ config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
+ assert config.endpoint == "https://custom.arize.com"
+
+ def test_endpoint_validation_invalid_scheme(self):
+ """Test endpoint validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ ArizeConfig(endpoint="ftp://invalid.com")
+
+ def test_endpoint_validation_no_scheme(self):
+ """Test endpoint validation rejects URLs without scheme"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ ArizeConfig(endpoint="invalid.com")
+
+
+class TestPhoenixConfig:
+ """Test cases for PhoenixConfig"""
+
+ def test_valid_config(self):
+ """Test valid Phoenix configuration"""
+ config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.phoenix.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = PhoenixConfig()
+ assert config.api_key is None
+ assert config.project is None
+ assert config.endpoint == "https://app.phoenix.arize.com"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = PhoenixConfig(project="")
+ assert config.project == "default"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation with path"""
+ config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
+ assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
+
+ def test_endpoint_validation_without_path(self):
+ """Test endpoint validation without path"""
+ config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
+ assert config.endpoint == "https://app.phoenix.arize.com"
diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml
new file mode 100644
index 0000000000..27d2273a69
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-langfuse"
+version = "0.0.1"
+dependencies = [
+ "langfuse>=4.2.0,<5.0.0",
+]
+description = "Dify ops tracing provider (Langfuse)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py
similarity index 100%
rename from api/core/ops/langsmith_trace/__init__.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py
diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py
new file mode 100644
index 0000000000..90d1a2846b
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py
@@ -0,0 +1,19 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class LangfuseConfig(BaseTracingConfig):
+ """
+ Model class for Langfuse tracing config.
+ """
+
+ public_key: str
+ secret_key: str
+ host: str = "https://api.langfuse.com"
+
+ @field_validator("host")
+ @classmethod
+ def host_validator(cls, v, info: ValidationInfo):
+ return validate_url_with_path(v, "https://api.langfuse.com")
diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py
similarity index 100%
rename from api/core/ops/langsmith_trace/entities/__init__.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py
diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py
similarity index 100%
rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py
similarity index 99%
rename from api/core/ops/langfuse_trace/langfuse_trace.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py
index 7eacc2be46..68881378a7 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py
@@ -16,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -28,7 +27,10 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
+from core.ops.utils import filter_none_values
+from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langfuse.entities.langfuse_trace_entity import (
GenerationUsage,
LangfuseGeneration,
LangfuseSpan,
@@ -36,8 +38,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
LevelEnum,
UnitEnum,
)
-from core.ops.utils import filter_none_values
-from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed
similarity index 100%
rename from api/core/ops/mlflow_trace/__init__.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed
diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py
similarity index 93%
rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py
rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py
index a0bcc92795..952f10c34f 100644
--- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py
+++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py
@@ -5,8 +5,16 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langfuse.entities.langfuse_trace_entity import (
+ LangfuseGeneration,
+ LangfuseSpan,
+ LangfuseTrace,
+ LevelEnum,
+ UnitEnum,
+)
+from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
-from core.ops.entities.config_entity import LangfuseConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -17,14 +25,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
- LangfuseGeneration,
- LangfuseSpan,
- LangfuseTrace,
- LevelEnum,
- UnitEnum,
-)
-from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from graphon.enums import BuiltinNodeTypes
from models import EndUser
from models.enums import MessageStatus
@@ -43,7 +43,7 @@ def langfuse_config():
def trace_instance(langfuse_config, monkeypatch):
# Mock Langfuse client to avoid network calls
mock_client = MagicMock()
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client)
instance = LangFuseDataTrace(langfuse_config)
return instance
@@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch):
def test_init(langfuse_config, monkeypatch):
mock_langfuse = MagicMock()
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = LangFuseDataTrace(langfuse_config)
@@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
# Mock DB and Repositories
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
# Mock node executions
node_llm = MagicMock()
@@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
@@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="log-1",
error="",
)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_trace = MagicMock()
trace_instance.add_generation = MagicMock()
@@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..103d888eef
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,42 @@
+import pytest
+from dify_trace_langfuse.config import LangfuseConfig
+from pydantic import ValidationError
+
+
+class TestLangfuseConfig:
+ """Test cases for LangfuseConfig"""
+
+ def test_valid_config(self):
+ """Test valid Langfuse configuration"""
+ config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
+ assert config.public_key == "public_key"
+ assert config.secret_key == "secret_key"
+ assert config.host == "https://custom.langfuse.com"
+
+ def test_valid_config_with_path(self):
+ host = "https://custom.langfuse.com/api/v1"
+ config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
+ assert config.public_key == "public_key"
+ assert config.secret_key == "secret_key"
+ assert config.host == host
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = LangfuseConfig(public_key="public", secret_key="secret")
+ assert config.host == "https://api.langfuse.com"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ LangfuseConfig()
+
+ with pytest.raises(ValidationError):
+ LangfuseConfig(public_key="public")
+
+ with pytest.raises(ValidationError):
+ LangfuseConfig(secret_key="secret")
+
+ def test_host_validation_empty(self):
+ """Test host validation with empty value"""
+ config = LangfuseConfig(public_key="public", secret_key="secret", host="")
+ assert config.host == "https://api.langfuse.com"
diff --git a/api/tests/unit_tests/core/ops/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py
similarity index 92%
rename from api/tests/unit_tests/core/ops/test_langfuse_trace.py
rename to api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py
index 017ac8c891..0340ffb669 100644
--- a/api/tests/unit_tests/core/ops/test_langfuse_trace.py
+++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py
@@ -4,14 +4,15 @@ from datetime import datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
-from core.ops.entities.config_entity import LangfuseConfig
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
+
from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo
-from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from graphon.enums import BuiltinNodeTypes
def _create_trace_instance() -> LangFuseDataTrace:
- with patch("core.ops.langfuse_trace.langfuse_trace.Langfuse", autospec=True):
+ with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True):
return LangFuseDataTrace(
LangfuseConfig(
public_key="public-key",
@@ -116,9 +117,9 @@ class TestLangFuseDataTraceCompletionStartTime:
patch.object(trace, "add_span"),
patch.object(trace, "add_generation") as add_generation,
patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()),
- patch("core.ops.langfuse_trace.langfuse_trace.db", MagicMock()),
+ patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()),
patch(
- "core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=repository,
),
):
diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml
new file mode 100644
index 0000000000..8131952b28
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-langsmith"
+version = "0.0.1"
+dependencies = [
+ "langsmith~=0.7.30",
+]
+description = "Dify ops tracing provider (LangSmith)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py
similarity index 100%
rename from api/core/ops/opik_trace/__init__.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py
diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py
new file mode 100644
index 0000000000..498b8c5e7e
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py
@@ -0,0 +1,20 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url
+
+
+class LangSmithConfig(BaseTracingConfig):
+ """
+ Model class for Langsmith tracing config.
+ """
+
+ api_key: str
+ project: str
+ endpoint: str = "https://api.smith.langchain.com"
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # LangSmith only allows HTTPS
+ return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py
similarity index 100%
rename from api/core/ops/tencent_trace/__init__.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py
diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py
similarity index 100%
rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py
similarity index 99%
rename from api/core/ops/langsmith_trace/langsmith_trace.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py
index d960038f15..145bd70dbc 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py
@@ -9,7 +9,6 @@ from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -21,13 +20,14 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
+from core.ops.utils import filter_none_values, generate_dotted_order
+from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_langsmith.config import LangSmithConfig
+from dify_trace_langsmith.entities.langsmith_trace_entity import (
LangSmithRunModel,
LangSmithRunType,
LangSmithRunUpdateModel,
)
-from core.ops.utils import filter_none_values, generate_dotted_order
-from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed
similarity index 100%
rename from api/core/ops/weave_trace/__init__.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed
diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py
similarity index 91%
rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py
rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py
index 34c64c54a1..45e5894e4a 100644
--- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py
+++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py
@@ -3,8 +3,14 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
+from dify_trace_langsmith.config import LangSmithConfig
+from dify_trace_langsmith.entities.langsmith_trace_entity import (
+ LangSmithRunModel,
+ LangSmithRunType,
+ LangSmithRunUpdateModel,
+)
+from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace
-from core.ops.entities.config_entity import LangSmithConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -15,12 +21,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
- LangSmithRunModel,
- LangSmithRunType,
- LangSmithRunUpdateModel,
-)
-from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser
@@ -38,7 +38,7 @@ def langsmith_config():
def trace_instance(langsmith_config, monkeypatch):
# Mock LangSmith client
mock_client = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client)
instance = LangSmithDataTrace(langsmith_config)
return instance
@@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch):
def test_init(langsmith_config, monkeypatch):
mock_client_class = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = LangSmithDataTrace(langsmith_config)
@@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch):
# Mock dependencies
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
# Mock node executions
node_llm = MagicMock()
@@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch):
)
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_run = MagicMock()
@@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
trace_info.error = ""
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch):
# Mock EndUser lookup
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_run = MagicMock()
@@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_run = MagicMock()
diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..37efaf69cf
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,35 @@
+import pytest
+from dify_trace_langsmith.config import LangSmithConfig
+from pydantic import ValidationError
+
+
+class TestLangSmithConfig:
+ """Test cases for LangSmithConfig"""
+
+ def test_valid_config(self):
+ """Test valid LangSmith configuration"""
+ config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.smith.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = LangSmithConfig(api_key="key", project="project")
+ assert config.endpoint == "https://api.smith.langchain.com"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ LangSmithConfig()
+
+ with pytest.raises(ValidationError):
+ LangSmithConfig(api_key="key")
+
+ with pytest.raises(ValidationError):
+ LangSmithConfig(project="project")
+
+ def test_endpoint_validation_https_only(self):
+ """Test endpoint validation only allows HTTPS"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml
new file mode 100644
index 0000000000..fad6002944
--- /dev/null
+++ b/api/providers/trace/trace-mlflow/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-mlflow"
+version = "0.0.1"
+dependencies = [
+ "mlflow-skinny>=3.11.1",
+]
+description = "Dify ops tracing provider (MLflow / Databricks)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py
similarity index 100%
rename from api/core/ops/weave_trace/entities/__init__.py
rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py
diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py
new file mode 100644
index 0000000000..84914165e3
--- /dev/null
+++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py
@@ -0,0 +1,46 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_integer_id, validate_url_with_path
+
+
+class MLflowConfig(BaseTracingConfig):
+ """
+ Model class for MLflow tracing config.
+ """
+
+ tracking_uri: str = "http://localhost:5000"
+ experiment_id: str = "0" # Default experiment id in MLflow is 0
+ username: str | None = None
+ password: str | None = None
+
+ @field_validator("tracking_uri")
+ @classmethod
+ def tracking_uri_validator(cls, v, info: ValidationInfo):
+ if isinstance(v, str) and v.startswith("databricks"):
+ raise ValueError(
+ "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances."
+ )
+ return validate_url_with_path(v, "http://localhost:5000")
+
+ @field_validator("experiment_id")
+ @classmethod
+ def experiment_id_validator(cls, v, info: ValidationInfo):
+ return validate_integer_id(v)
+
+
+class DatabricksConfig(BaseTracingConfig):
+ """
+ Model class for Databricks (Databricks-managed MLflow) tracing config.
+ """
+
+ experiment_id: str
+ host: str
+ client_id: str | None = None
+ client_secret: str | None = None
+ personal_access_token: str | None = None
+
+ @field_validator("experiment_id")
+ @classmethod
+ def experiment_id_validator(cls, v, info: ValidationInfo):
+ return validate_integer_id(v)
diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
similarity index 99%
rename from api/core/ops/mlflow_trace/mlflow_trace.py
rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
index 87fcaeabcc..4e4c45a532 100644
--- a/api/core/ops/mlflow_trace/mlflow_trace.py
+++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
@@ -11,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex
from sqlalchemy import select
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -24,6 +23,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
+from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes
from models import EndUser
diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py
similarity index 98%
rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py
index afc5726ede..20211456e3 100644
--- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py
+++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py
@@ -1,4 +1,4 @@
-"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module."""
+"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module."""
from __future__ import annotations
@@ -9,8 +9,9 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
+from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig
+from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
-from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -20,7 +21,6 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds
from graphon.enums import BuiltinNodeTypes
# ── Helpers ──────────────────────────────────────────────────────────────────
@@ -179,7 +179,7 @@ def _make_node(**overrides):
@pytest.fixture
def mock_mlflow():
- with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock:
+ with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock:
yield mock
@@ -187,10 +187,10 @@ def mock_mlflow():
def mock_tracing():
"""Patch all MLflow tracing functions used by the module."""
with (
- patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start,
- patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update,
- patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set,
- patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach,
+ patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start,
+ patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update,
+ patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set,
+ patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach,
):
yield {
"start": mock_start,
@@ -202,7 +202,7 @@ def mock_tracing():
@pytest.fixture
def mock_db():
- with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock:
+ with patch("dify_trace_mlflow.mlflow_trace.db") as mock:
yield mock
diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml
new file mode 100644
index 0000000000..874997168e
--- /dev/null
+++ b/api/providers/trace/trace-opik/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-opik"
+version = "0.0.1"
+dependencies = [
+ "opik~=1.11.2",
+]
+description = "Dify ops tracing provider (Opik)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/config.py b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py
new file mode 100644
index 0000000000..c16ff1d903
--- /dev/null
+++ b/api/providers/trace/trace-opik/src/dify_trace_opik/config.py
@@ -0,0 +1,25 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class OpikConfig(BaseTracingConfig):
+ """
+ Model class for Opik tracing config.
+ """
+
+ api_key: str | None = None
+ project: str | None = None
+ workspace: str | None = None
+ url: str = "https://www.comet.com/opik/api/"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "Default Project")
+
+ @field_validator("url")
+ @classmethod
+ def url_validator(cls, v, info: ValidationInfo):
+ return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py
similarity index 99%
rename from api/core/ops/opik_trace/opik_trace.py
rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py
index 672efe45bd..2d124ac989 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py
@@ -10,7 +10,6 @@ from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -23,6 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_opik.config import OpikConfig
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
similarity index 93%
rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
index c02ac413f2..eefed3c78c 100644
--- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
+++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
@@ -5,8 +5,9 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
+from dify_trace_opik.config import OpikConfig
+from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
-from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -17,7 +18,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser
from models.enums import MessageStatus
@@ -37,7 +37,7 @@ def opik_config():
@pytest.fixture
def trace_instance(opik_config, monkeypatch):
mock_client = MagicMock()
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
instance = OpikDataTrace(opik_config)
return instance
@@ -67,7 +67,7 @@ def test_prepare_opik_uuid():
def test_init(opik_config, monkeypatch):
mock_opik = MagicMock()
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
monkeypatch.setenv("FILES_URL", "http://test.url")
instance = OpikDataTrace(opik_config)
@@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
)
mock_session = MagicMock()
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
node_llm = MagicMock()
node_llm.id = LLM_NODE_ID
@@ -203,7 +203,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
@@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
error="",
)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
@@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
error="",
)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
with pytest.raises(ValueError, match="No app_id found in trace_info metadata"):
trace_instance.workflow_trace(trace_info)
@@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch):
mock_end_user = MagicMock(spec=EndUser)
mock_end_user.session_id = "session-id-123"
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user)
trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2"))
trace_instance.add_span = MagicMock()
@@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch
repo.get_by_workflow_execution.return_value = [node]
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
- monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
+ monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_trace = MagicMock()
diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..5a54b70bba
--- /dev/null
+++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,48 @@
+import pytest
+from dify_trace_opik.config import OpikConfig
+from pydantic import ValidationError
+
+
+class TestOpikConfig:
+ """Test cases for OpikConfig"""
+
+ def test_valid_config(self):
+ """Test valid Opik configuration"""
+ config = OpikConfig(
+ api_key="test_key",
+ project="test_project",
+ workspace="test_workspace",
+ url="https://custom.comet.com/opik/api/",
+ )
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.workspace == "test_workspace"
+ assert config.url == "https://custom.comet.com/opik/api/"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = OpikConfig()
+ assert config.api_key is None
+ assert config.project is None
+ assert config.workspace is None
+ assert config.url == "https://www.comet.com/opik/api/"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = OpikConfig(project="")
+ assert config.project == "Default Project"
+
+ def test_url_validation_empty(self):
+ """Test URL validation with empty value"""
+ config = OpikConfig(url="")
+ assert config.url == "https://www.comet.com/opik/api/"
+
+ def test_url_validation_missing_suffix(self):
+ """Test URL validation requires /api/ suffix"""
+ with pytest.raises(ValidationError, match="URL should end with /api/"):
+ OpikConfig(url="https://custom.comet.com/opik/")
+
+ def test_url_validation_invalid_scheme(self):
+ """Test URL validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+ OpikConfig(url="ftp://custom.comet.com/opik/api/")
diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
similarity index 94%
rename from api/tests/unit_tests/core/ops/test_opik_trace.py
rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
index ad9d0846be..fba290f5b8 100644
--- a/api/tests/unit_tests/core/ops/test_opik_trace.py
+++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
@@ -14,8 +14,9 @@ import uuid
from datetime import datetime
from unittest.mock import MagicMock, patch
+from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
+
from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo
-from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
# A stable UUID4 used as the workflow_run_id throughout all tests.
_WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6"
@@ -56,8 +57,8 @@ def _make_workflow_trace_info(
def _make_opik_trace_instance() -> OpikDataTrace:
"""Construct an OpikDataTrace with the Opik SDK client mocked out."""
- with patch("core.ops.opik_trace.opik_trace.Opik"):
- from core.ops.entities.config_entity import OpikConfig
+ with patch("dify_trace_opik.opik_trace.Opik"):
+ from dify_trace_opik.config import OpikConfig
config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/")
instance = OpikDataTrace(config)
@@ -133,10 +134,10 @@ class TestWorkflowTraceWithoutMessageId:
fake_repo.get_by_workflow_execution.return_value = node_executions or []
with (
- patch("core.ops.opik_trace.opik_trace.db") as mock_db,
- patch("core.ops.opik_trace.opik_trace.sessionmaker"),
+ patch("dify_trace_opik.opik_trace.db") as mock_db,
+ patch("dify_trace_opik.opik_trace.sessionmaker"),
patch(
- "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=fake_repo,
),
):
@@ -265,10 +266,10 @@ class TestWorkflowTraceWithMessageId:
fake_repo.get_by_workflow_execution.return_value = node_executions or []
with (
- patch("core.ops.opik_trace.opik_trace.db") as mock_db,
- patch("core.ops.opik_trace.opik_trace.sessionmaker"),
+ patch("dify_trace_opik.opik_trace.db") as mock_db,
+ patch("dify_trace_opik.opik_trace.sessionmaker"),
patch(
- "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+ "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
return_value=fake_repo,
),
):
diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml
new file mode 100644
index 0000000000..eab06fc708
--- /dev/null
+++ b/api/providers/trace/trace-tencent/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-trace-tencent"
+version = "0.0.1"
+dependencies = [
+ # versions inherited from parent
+ "opentelemetry-api",
+ "opentelemetry-exporter-otlp-proto-grpc",
+ "opentelemetry-sdk",
+ "opentelemetry-semantic-conventions",
+]
+description = "Dify ops tracing provider (Tencent APM)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
similarity index 100%
rename from api/core/ops/tencent_trace/client.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
new file mode 100644
index 0000000000..398e6c55a8
--- /dev/null
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
@@ -0,0 +1,30 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+
+
+class TencentConfig(BaseTracingConfig):
+ """
+ Tencent APM tracing config
+ """
+
+ token: str
+ endpoint: str
+ service_name: str
+
+ @field_validator("token")
+ @classmethod
+ def token_validator(cls, v, info: ValidationInfo):
+ if not v or v.strip() == "":
+ raise ValueError("Token cannot be empty")
+ return v
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
+
+ @field_validator("service_name")
+ @classmethod
+ def service_name_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "dify_app")
diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/__init__.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/semconv.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
similarity index 98%
rename from api/core/ops/tencent_trace/span_builder.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
index 36878dc58f..763a85ffd7 100644
--- a/api/core/ops/tencent_trace/span_builder.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
@@ -14,7 +14,8 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.tencent_trace.entities.semconv import (
+from core.rag.models.document import Document
+from dify_trace_tencent.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_IS_ENTRY,
@@ -38,9 +39,8 @@ from core.ops.tencent_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.utils import TencentTraceUtils
-from core.rag.models.document import Document
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.utils import TencentTraceUtils
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
similarity index 94%
rename from api/core/ops/tencent_trace/tencent_trace.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
index d681b9da80..a8c480e4a5 100644
--- a/api/core/ops/tencent_trace/tencent_trace.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
@@ -1,14 +1,12 @@
-"""
-Tencent APM tracing implementation with separated concerns
-"""
+"""Tencent APM tracing with idempotent client cleanup."""
+import inspect
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -19,11 +17,12 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.tencent_trace.client import TencentTraceClient
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
-from core.ops.tencent_trace.utils import TencentTraceUtils
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from dify_trace_tencent.client import TencentTraceClient
+from dify_trace_tencent.config import TencentConfig
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from dify_trace_tencent.utils import TencentTraceUtils
from extensions.ext_database import db
from graphon.entities.workflow_node_execution import (
WorkflowNodeExecution,
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
"""
Tencent APM trace implementation with single responsibility principle.
Acts as a coordinator that delegates specific tasks to specialized classes.
+
+ The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
+ explicitly in tests or implicitly during garbage collection, so shutdown
+ must be safe to call multiple times.
"""
+ trace_client: TencentTraceClient
+ _closed: bool
+
def __init__(self, tencent_config: TencentConfig):
super().__init__(tencent_config)
+ self._closed = False
self.trace_client = TencentTraceClient(
service_name=tencent_config.service_name,
endpoint=tencent_config.endpoint,
@@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
except Exception:
logger.debug("[Tencent APM] Failed to record message trace duration")
- def __del__(self):
- """Ensure proper cleanup on garbage collection."""
+ def close(self) -> None:
+ """Synchronously and idempotently shutdown the underlying trace client."""
+ if getattr(self, "_closed", False):
+ return
+
+ self._closed = True
+ trace_client = getattr(self, "trace_client", None)
+ if trace_client is None:
+ return
+
try:
- if hasattr(self, "trace_client"):
- self.trace_client.shutdown()
+ shutdown_result = trace_client.shutdown()
+ if inspect.isawaitable(shutdown_result):
+ close_awaitable = getattr(shutdown_result, "close", None)
+ if callable(close_awaitable):
+ close_awaitable()
except Exception:
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def __del__(self):
+ """Ensure best-effort cleanup on garbage collection without retrying shutdown."""
+ self.close()
diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
similarity index 100%
rename from api/core/ops/tencent_trace/utils.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
similarity index 98%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
index 870c18e53e..1e656e2462 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
@@ -8,13 +8,12 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
+from dify_trace_tencent import client as client_module
+from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
-from core.ops.tencent_trace import client as client_module
-from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-
metric_reader_instances: list[DummyMetricReader] = []
meter_provider_instances: list[DummyMeterProvider] = []
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
similarity index 89%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
index 6113e5c6c8..e850a801f3 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
@@ -1,15 +1,7 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
-from opentelemetry.trace import StatusCode
-
-from core.ops.entities.trace_entity import (
- DatasetRetrievalTraceInfo,
- MessageTraceInfo,
- ToolTraceInfo,
- WorkflowTraceInfo,
-)
-from core.ops.tencent_trace.entities.semconv import (
+from dify_trace_tencent.entities.semconv import (
GEN_AI_IS_ENTRY,
GEN_AI_IS_STREAMING_REQUEST,
GEN_AI_MODEL_NAME,
@@ -23,7 +15,15 @@ from core.ops.tencent_trace.entities.semconv import (
TOOL_PARAMETERS,
GenAISpanKind,
)
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from opentelemetry.trace import StatusCode
+
+from core.ops.entities.trace_entity import (
+ DatasetRetrievalTraceInfo,
+ MessageTraceInfo,
+ ToolTraceInfo,
+ WorkflowTraceInfo,
+)
from core.rag.models.document import Document
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -31,7 +31,7 @@ from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutio
class TestTencentSpanBuilder:
def test_get_time_nanoseconds(self):
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
mock_convert.return_value = 123456789
dt = datetime.now()
result = TencentSpanBuilder._get_time_nanoseconds(dt)
@@ -48,7 +48,7 @@ class TestTencentSpanBuilder:
trace_info.workflow_run_outputs = {"answer": "world"}
trace_info.metadata = {"conversation_id": "conv_id"}
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -70,7 +70,7 @@ class TestTencentSpanBuilder:
trace_info.workflow_run_outputs = {}
trace_info.metadata = {} # No conversation_id
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 1
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1")
@@ -98,7 +98,7 @@ class TestTencentSpanBuilder:
}
node_execution.outputs = {"text": "world"}
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 456
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -123,7 +123,7 @@ class TestTencentSpanBuilder:
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
}
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 456
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution)
@@ -142,7 +142,7 @@ class TestTencentSpanBuilder:
trace_info.metadata = {"conversation_id": "conv_id"}
trace_info.is_streaming_request = True
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 789
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -162,7 +162,7 @@ class TestTencentSpanBuilder:
trace_info.metadata = {}
trace_info.is_streaming_request = False
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 789
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1")
@@ -182,7 +182,7 @@ class TestTencentSpanBuilder:
trace_info.tool_inputs = {"i": 2}
trace_info.tool_outputs = "result"
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 101
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1)
@@ -204,7 +204,7 @@ class TestTencentSpanBuilder:
)
trace_info.documents = [doc]
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 202
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -222,7 +222,7 @@ class TestTencentSpanBuilder:
trace_info.end_time = datetime.now()
trace_info.documents = []
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 202
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -264,7 +264,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 303
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -286,7 +286,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 303
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -307,7 +307,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 404
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -329,7 +329,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 404
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -350,7 +350,7 @@ class TestTencentSpanBuilder:
node_execution.created_at = datetime.now()
node_execution.finished_at = datetime.now()
- with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+ with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
mock_convert_id.return_value = 505
with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution)
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
similarity index 86%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
index 7afd0b824a..54524b09ca 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py
@@ -1,9 +1,12 @@
+import gc
import logging
-from unittest.mock import MagicMock, patch
+import warnings
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
+from dify_trace_tencent.config import TencentConfig
+from dify_trace_tencent.tencent_trace import TencentDataTrace
-from core.ops.entities.config_entity import TencentConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -13,7 +16,6 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
-from core.ops.tencent_trace.tencent_trace import TencentDataTrace
from graphon.entities import WorkflowNodeExecution
from graphon.enums import BuiltinNodeTypes
from models import Account, App, TenantAccountJoin
@@ -28,19 +30,19 @@ def tencent_config():
@pytest.fixture
def mock_trace_client():
- with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock:
+ with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock:
yield mock
@pytest.fixture
def mock_span_builder():
- with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock:
+ with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock:
yield mock
@pytest.fixture
def mock_trace_utils():
- with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock:
+ with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock:
yield mock
@@ -198,9 +200,9 @@ class TestTencentDataTrace:
trace_info.workflow_run_id = "run-id"
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.workflow_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
@@ -230,9 +232,9 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.message_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
@@ -262,9 +264,9 @@ class TestTencentDataTrace:
trace_info.message_id = "msg-id"
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.tool_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
@@ -294,22 +296,22 @@ class TestTencentDataTrace:
trace_info.message_id = "msg-id"
with patch(
- "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
+ "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.dataset_retrieval_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
def test_suggested_question_trace(self, tencent_data_trace):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
- with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log:
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
def test_suggested_question_trace_exception(self, tencent_data_trace):
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
- with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")):
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace.suggested_question_trace(trace_info)
mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
@@ -342,7 +344,7 @@ class TestTencentDataTrace:
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace._process_workflow_nodes(trace_info, 123)
# The exception should be caught by the outer handler since convert_to_span_id is called first
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -351,7 +353,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
tencent_data_trace._process_workflow_nodes(trace_info, 123)
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
@@ -381,7 +383,7 @@ class TestTencentDataTrace:
node.id = "n1"
mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
assert result is None
mock_log.assert_called_once()
@@ -403,15 +405,13 @@ class TestTencentDataTrace:
mock_executions = [MagicMock()]
- with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+ with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.engine = "engine"
- with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+ with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.side_effect = [app, account, tenant_join]
- with patch(
- "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
- ) as mock_repo:
+ with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo:
mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions
results = tencent_data_trace._get_workflow_node_executions(trace_info)
@@ -423,7 +423,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {}
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
@@ -432,14 +432,14 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.metadata = {"app_id": "app-1"}
- with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+ with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.init_app = MagicMock() # Ensure init_app is mocked
mock_db.engine = "engine"
- with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
+ with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.return_value = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
results = tencent_data_trace._get_workflow_node_executions(trace_info)
assert results == []
mock_log.assert_called_once()
@@ -449,8 +449,8 @@ class TestTencentDataTrace:
trace_info.tenant_id = "tenant-1"
trace_info.metadata = {"user_id": "user-1"}
- with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
- with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db:
+ with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
+ with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
mock_db.init_app = MagicMock()
mock_db.engine = MagicMock()
@@ -476,8 +476,8 @@ class TestTencentDataTrace:
trace_info.tenant_id = "t"
trace_info.metadata = {"user_id": "u"}
- with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")):
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
user_id = tencent_data_trace._get_user_id(trace_info)
assert user_id == "unknown"
mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
@@ -519,7 +519,7 @@ class TestTencentDataTrace:
node.process_data = None
node.outputs = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_llm_metrics(node)
# Should not crash
@@ -557,7 +557,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.metadata = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_llm_metrics(trace_info)
# Should not crash
@@ -609,7 +609,7 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=WorkflowTraceInfo)
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_workflow_trace_duration(trace_info)
def test_record_message_trace_duration(self, tencent_data_trace):
@@ -631,16 +631,41 @@ class TestTencentDataTrace:
trace_info = MagicMock(spec=MessageTraceInfo)
trace_info.start_time = None
- with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log:
+ with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
tencent_data_trace._record_message_trace_duration(trace_info)
- def test_del(self, tencent_data_trace):
+ def test_close(self, tencent_data_trace):
client = tencent_data_trace.trace_client
- tencent_data_trace.__del__()
+ tencent_data_trace.close()
client.shutdown.assert_called_once()
- def test_del_exception(self, tencent_data_trace):
+ def test_close_is_idempotent(self, tencent_data_trace):
+ client = tencent_data_trace.trace_client
+
+ tencent_data_trace.close()
+ tencent_data_trace.close()
+
+ client.shutdown.assert_called_once()
+
+ def test_close_exception(self, tencent_data_trace):
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
- with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log:
- tencent_data_trace.__del__()
+ with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
+ tencent_data_trace.close()
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+ def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
+ shutdown = AsyncMock()
+ tencent_data_trace.trace_client.shutdown = shutdown
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ tencent_data_trace.close()
+ gc.collect()
+
+ shutdown.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning)
+ and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
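
The renamed tests above pin down the contract of the new close() hook that replaces __del__: shutdown runs once, repeat calls are no-ops, and shutdown failures are logged rather than raised. A minimal sketch consistent with those tests (the _closed flag is an assumption; the actual attribute in dify_trace_tencent.tencent_trace may differ):

```python
# Sketch only, not the module's actual body: an idempotent close().
def close(self) -> None:
    if getattr(self, "_closed", False):
        return  # repeat calls do nothing (test_close_is_idempotent)
    self._closed = True
    try:
        self.trace_client.shutdown()
    except Exception:
        logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
```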
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
similarity index 88%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
index ef28d18e20..63c6d680d7 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py
@@ -8,10 +8,9 @@ from datetime import UTC, datetime
from unittest.mock import patch
import pytest
+from dify_trace_tencent.utils import TencentTraceUtils
from opentelemetry.trace import Link, TraceFlags
-from core.ops.tencent_trace.utils import TencentTraceUtils
-
def test_convert_to_trace_id_with_valid_uuid() -> None:
uuid_str = "12345678-1234-5678-1234-567812345678"
@@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None:
def test_convert_to_trace_id_uses_uuid4_when_none() -> None:
expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
- with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+ with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int
uuid4_mock.assert_called_once()
@@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None:
def test_convert_to_span_id_uses_uuid4_when_none() -> None:
expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
- with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
+ with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock:
span_id = TencentTraceUtils.convert_to_span_id(None, "workflow")
assert isinstance(span_id, int)
uuid4_mock.assert_called_once()
@@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None:
def test_generate_span_id_skips_invalid_span_id() -> None:
with patch(
- "core.ops.tencent_trace.utils.random.getrandbits",
+ "dify_trace_tencent.utils.random.getrandbits",
side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42],
) as bits_mock:
assert TencentTraceUtils.generate_span_id() == 42
@@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None:
fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
expected = int(fixed.timestamp() * 1e9)
- with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock:
+ with patch("dify_trace_tencent.utils.datetime") as datetime_mock:
datetime_mock.now.return_value = fixed
assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected
datetime_mock.now.assert_called_once()
@@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i
@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None])
def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None:
fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd")
- with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
+ with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock:
link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type]
assert link.context.trace_id == fallback_uuid.int
uuid4_mock.assert_called_once()
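
Taken together, the patched uuid4 fallbacks above encode a simple contract for the ID converters; roughly (a sketch, not the real dify_trace_tencent.utils body):

```python
import uuid

# Sketch of the contract the tests above exercise: a UUID string converts to
# its 128-bit integer form, and None falls back to a freshly generated uuid4.
def convert_to_trace_id(trace_id: str | None) -> int:
    return uuid.UUID(trace_id).int if trace_id else uuid.uuid4().int
```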
diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml
new file mode 100644
index 0000000000..ba449f2a93
--- /dev/null
+++ b/api/providers/trace/trace-weave/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-weave"
+version = "0.0.1"
+dependencies = [
+ "weave>=0.52.36",
+]
+description = "Dify ops tracing provider (Weave)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py
new file mode 100644
index 0000000000..5942bd57fe
--- /dev/null
+++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py
@@ -0,0 +1,29 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url
+
+
+class WeaveConfig(BaseTracingConfig):
+ """
+ Model class for Weave tracing config.
+ """
+
+ api_key: str
+ entity: str | None = None
+ project: str
+ endpoint: str = "https://trace.wandb.ai"
+ host: str | None = None
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # Weave only allows HTTPS for endpoint
+ return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
+
+ @field_validator("host")
+ @classmethod
+ def host_validator(cls, v, info: ValidationInfo):
+ if v is not None and v.strip() != "":
+ return validate_url(v, v, allowed_schemes=("https", "http"))
+ return v
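
The validators above are covered by the new unit tests further down; as a quick illustration of the intended behavior:

```python
from pydantic import ValidationError

from dify_trace_weave.config import WeaveConfig

config = WeaveConfig(api_key="key", project="proj")
assert config.endpoint == "https://trace.wandb.ai"  # default endpoint

try:
    WeaveConfig(api_key="key", project="proj", endpoint="http://insecure.wandb.ai")
except ValidationError:
    pass  # endpoint_validator permits only the https scheme
```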
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py
similarity index 100%
rename from api/core/ops/weave_trace/entities/weave_trace_entity.py
rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py
diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py
similarity index 99%
rename from api/core/ops/weave_trace/weave_trace.py
rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py
index f79544f1c7..4292cbf0f1 100644
--- a/api/core/ops/weave_trace/weave_trace.py
+++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py
@@ -17,7 +17,6 @@ from weave.trace_server.trace_server_interface import (
)
from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
@@ -29,8 +28,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_weave.config import WeaveConfig
+from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
from extensions.ext_database import db
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..eeb1fe1d87
--- /dev/null
+++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,61 @@
+import pytest
+from dify_trace_weave.config import WeaveConfig
+from pydantic import ValidationError
+
+
+class TestWeaveConfig:
+ """Test cases for WeaveConfig"""
+
+ def test_valid_config(self):
+ """Test valid Weave configuration"""
+ config = WeaveConfig(
+ api_key="test_key",
+ entity="test_entity",
+ project="test_project",
+ endpoint="https://custom.wandb.ai",
+ host="https://custom.host.com",
+ )
+ assert config.api_key == "test_key"
+ assert config.entity == "test_entity"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.wandb.ai"
+ assert config.host == "https://custom.host.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = WeaveConfig(api_key="key", project="project")
+ assert config.entity is None
+ assert config.endpoint == "https://trace.wandb.ai"
+ assert config.host is None
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ WeaveConfig()
+
+ with pytest.raises(ValidationError):
+ WeaveConfig(api_key="key")
+
+ with pytest.raises(ValidationError):
+ WeaveConfig(project="project")
+
+ def test_endpoint_validation_https_only(self):
+ """Test endpoint validation only allows HTTPS"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
+
+ def test_host_validation_optional(self):
+ """Test host validation is optional but validates when provided"""
+ config = WeaveConfig(api_key="key", project="project", host=None)
+ assert config.host is None
+
+ config = WeaveConfig(api_key="key", project="project", host="")
+ assert config.host == ""
+
+ config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
+ assert config.host == "https://valid.host.com"
+
+ def test_host_validation_invalid_scheme(self):
+ """Test host validation rejects invalid schemes when provided"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py
similarity index 97%
rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py
index 531c7de05f..6028d0c550 100644
--- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py
+++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py
@@ -1,4 +1,4 @@
-"""Comprehensive tests for core.ops.weave_trace.weave_trace module."""
+"""Comprehensive tests for dify_trace_weave.weave_trace module."""
from __future__ import annotations
@@ -7,9 +7,11 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
+from dify_trace_weave.config import WeaveConfig
+from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel
+from dify_trace_weave.weave_trace import WeaveDataTrace
from weave.trace_server.trace_server_interface import TraceStatus
-from core.ops.entities.config_entity import WeaveConfig
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -20,8 +22,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
-from core.ops.weave_trace.weave_trace import WeaveDataTrace
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
# ── Helpers ──────────────────────────────────────────────────────────────────
@@ -191,14 +191,14 @@ def _make_node(**overrides):
@pytest.fixture
def mock_wandb():
- with patch("core.ops.weave_trace.weave_trace.wandb") as mock:
+ with patch("dify_trace_weave.weave_trace.wandb") as mock:
mock.login.return_value = True
yield mock
@pytest.fixture
def mock_weave():
- with patch("core.ops.weave_trace.weave_trace.weave") as mock:
+ with patch("dify_trace_weave.weave_trace.weave") as mock:
client = MagicMock()
client.entity = "my-entity"
client.project = "my-project"
@@ -307,7 +307,7 @@ class TestGetProjectUrl:
monkeypatch.setattr(trace_instance, "entity", None)
monkeypatch.setattr(trace_instance, "project_name", None)
# Force an error by making string formatting fail
- with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger:
+ with patch("dify_trace_weave.weave_trace.logger") as mock_logger:
# Simulate exception via property
original_entity = trace_instance.entity
trace_instance.entity = None
@@ -594,9 +594,9 @@ class TestWorkflowTrace:
mock_factory = MagicMock()
mock_factory.create_workflow_node_execution_repository.return_value = repo
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory)
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
return repo
def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch):
@@ -703,8 +703,8 @@ class TestWorkflowTrace:
def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch):
"""Raises ValueError when app_id is missing from metadata."""
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock())
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine"))
+ monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock())
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine"))
trace_info = _make_workflow_trace_info(
message_id=None,
@@ -802,7 +802,7 @@ class TestMessageTrace:
def test_basic_message_trace(self, trace_instance, monkeypatch):
"""message_trace creates message run and llm child run."""
monkeypatch.setattr(
- "core.ops.weave_trace.weave_trace.db.session.get",
+ "dify_trace_weave.weave_trace.db.session.get",
lambda model, pk: None,
)
@@ -824,7 +824,7 @@ class TestMessageTrace:
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -846,7 +846,7 @@ class TestMessageTrace:
mock_db = MagicMock()
mock_db.session.get.return_value = end_user
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -866,7 +866,7 @@ class TestMessageTrace:
"""message_trace handles when from_end_user_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -884,7 +884,7 @@ class TestMessageTrace:
"""trace_id falls back to message_id when trace_id is None."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
@@ -899,7 +899,7 @@ class TestMessageTrace:
"""message_trace handles file_list=None gracefully."""
mock_db = MagicMock()
mock_db.session.get.return_value = None
- monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db)
+ monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db)
trace_instance.start_call = MagicMock()
trace_instance.finish_call = MagicMock()
diff --git a/api/pyproject.toml b/api/pyproject.toml
index a1ceea181e..8f6ee796ab 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -32,9 +32,6 @@ dependencies = [
"flask-restx>=1.3.2,<2.0.0",
"google-cloud-aiplatform>=1.147.0,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
- "langfuse>=4.2.0,<5.0.0",
- "langsmith>=0.7.31,<1.0.0",
- "mlflow-skinny>=3.11.1,<4.0.0",
"opentelemetry-distro>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
"opentelemetry-instrumentation-flask>=0.62b0,<1.0.0",
@@ -44,15 +41,12 @@ dependencies = [
"opentelemetry-propagator-b3>=1.41.0,<2.0.0",
"readabilipy>=0.3.0,<1.0.0",
"resend>=2.27.0,<3.0.0",
- "weave>=0.52.36,<1.0.0",
# Emerging: newer and fast-moving, use compatible pins
- "arize-phoenix-otel~=0.15.0",
"fastopenapi[flask]~=0.7.0",
- "graphon~=0.1.2",
+ "graphon~=0.2.2",
"httpx-sse~=0.4.0",
"json-repair~=0.59.2",
- "opik~=1.11.2",
]
# Before adding a new dependency, consider placing it in
# alphabetical order (a-z) and in a suitable group.
@@ -61,8 +55,8 @@ dependencies = [
packages = []
[tool.uv.workspace]
-members = ["providers/vdb/*"]
-exclude = ["providers/vdb/__pycache__"]
+members = ["providers/vdb/*", "providers/trace/*"]
+exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"]
[tool.uv.sources]
dify-vdb-alibabacloud-mysql = { workspace = true }
@@ -95,9 +89,17 @@ dify-vdb-upstash = { workspace = true }
dify-vdb-vastbase = { workspace = true }
dify-vdb-vikingdb = { workspace = true }
dify-vdb-weaviate = { workspace = true }
+dify-trace-aliyun = { workspace = true }
+dify-trace-arize-phoenix = { workspace = true }
+dify-trace-langfuse = { workspace = true }
+dify-trace-langsmith = { workspace = true }
+dify-trace-mlflow = { workspace = true }
+dify-trace-opik = { workspace = true }
+dify-trace-tencent = { workspace = true }
+dify-trace-weave = { workspace = true }
[tool.uv]
-default-groups = ["storage", "tools", "vdb-all"]
+default-groups = ["storage", "tools", "vdb-all", "trace-all"]
package = false
override-dependencies = [
"pyarrow>=18.0.0",
@@ -266,6 +268,25 @@ vdb-weaviate = ["dify-vdb-weaviate"]
# Optional client used by some tests / integrations (not a vector backend plugin)
vdb-xinference = ["xinference-client>=2.4.0"]
+trace-all = [
+ "dify-trace-aliyun",
+ "dify-trace-arize-phoenix",
+ "dify-trace-langfuse",
+ "dify-trace-langsmith",
+ "dify-trace-mlflow",
+ "dify-trace-opik",
+ "dify-trace-tencent",
+ "dify-trace-weave",
+]
+trace-aliyun = ["dify-trace-aliyun"]
+trace-arize-phoenix = ["dify-trace-arize-phoenix"]
+trace-langfuse = ["dify-trace-langfuse"]
+trace-langsmith = ["dify-trace-langsmith"]
+trace-mlflow = ["dify-trace-mlflow"]
+trace-opik = ["dify-trace-opik"]
+trace-tencent = ["dify-trace-tencent"]
+trace-weave = ["dify-trace-weave"]
+
[tool.pyrefly]
project-includes = ["."]
project-excludes = [".venv", "migrations/"]
diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt
index 3e5ece1fcf..fbbca24558 100644
--- a/api/pyrefly-local-excludes.txt
+++ b/api/pyrefly-local-excludes.txt
@@ -34,12 +34,12 @@ core/external_data_tool/api/api.py
core/llm_generator/llm_generator.py
core/llm_generator/output_parser/structured_output.py
core/mcp/mcp_client.py
-core/ops/aliyun_trace/data_exporter/traceclient.py
-core/ops/arize_phoenix_trace/arize_phoenix_trace.py
-core/ops/mlflow_trace/mlflow_trace.py
+providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py
+providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
+providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
core/ops/ops_trace_manager.py
-core/ops/tencent_trace/client.py
-core/ops/tencent_trace/utils.py
+providers/trace/trace-tencent/src/dify_trace_tencent/client.py
+providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json
index c4582e891d..ac0e2a3a53 100644
--- a/api/pyrightconfig.json
+++ b/api/pyrightconfig.json
@@ -5,7 +5,8 @@
".venv",
"migrations/",
"core/rag",
- "providers/",
+ "providers/vdb/",
+ "providers/trace/*/tests",
],
"typeCheckingMode": "strict",
"allowedUntypedLibraries": [
diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py
index 8479cdfb0c..2cc0192a4a 100644
--- a/api/schedule/mail_clean_document_notify_task.py
+++ b/api/schedule/mail_clean_document_notify_task.py
@@ -7,8 +7,8 @@ from sqlalchemy import select
import app
from configs import dify_config
+from core.db.session_factory import session_factory
from enums.cloud_plan import CloudPlan
-from extensions.ext_database import db
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
from models import Account, Tenant, TenantAccountJoin
@@ -33,67 +33,68 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
- dataset_auto_disable_logs = db.session.scalars(
- select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
- ).all()
- # group by tenant_id
- dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
- for dataset_auto_disable_log in dataset_auto_disable_logs:
- if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
- dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
- dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
- url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
- for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
- features = FeatureService.get_features(tenant_id)
- plan = features.billing.subscription.plan
- if plan != CloudPlan.SANDBOX:
- knowledge_details = []
- # check tenant
- tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
- if not tenant:
- continue
- # check current owner
- current_owner_join = db.session.scalar(
- select(TenantAccountJoin)
- .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
- .limit(1)
- )
- if not current_owner_join:
- continue
- account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
- if not account:
- continue
+ with session_factory.create_session() as session:
+ dataset_auto_disable_logs = session.scalars(
+ select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
+ ).all()
+ # group by tenant_id
+ dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
+ for dataset_auto_disable_log in dataset_auto_disable_logs:
+ if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
+ dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
+ dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
+ url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
+ for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
+ features = FeatureService.get_features(tenant_id)
+ plan = features.billing.subscription.plan
+ if plan != CloudPlan.SANDBOX:
+ knowledge_details = []
+ # check tenant
+ tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
+ if not tenant:
+ continue
+ # check current owner
+ current_owner_join = session.scalar(
+ select(TenantAccountJoin)
+ .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
+ .limit(1)
+ )
+ if not current_owner_join:
+ continue
+ account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
+ if not account:
+ continue
- dataset_auto_dataset_map = {} # type: ignore
+ dataset_auto_dataset_map = {} # type: ignore
+ for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
+ if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
+ dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
+ dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
+ dataset_auto_disable_log.document_id
+ )
+
+ for dataset_id, document_ids in dataset_auto_dataset_map.items():
+ dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
+ if dataset:
+ document_count = len(document_ids)
+ knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
+ if knowledge_details:
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
+ language_code="en-US",
+ to=account.email,
+ template_context={
+ "userName": account.email,
+ "knowledge_details": knowledge_details,
+ "url": url,
+ },
+ )
+
+ # update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
- if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
- dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
- dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
- dataset_auto_disable_log.document_id
- )
-
- for dataset_id, document_ids in dataset_auto_dataset_map.items():
- dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
- if dataset:
- document_count = len(document_ids)
- knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
- if knowledge_details:
- email_service = get_email_i18n_service()
- email_service.send_email(
- email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
- language_code="en-US",
- to=account.email,
- template_context={
- "userName": account.email,
- "knowledge_details": knowledge_details,
- "url": url,
- },
- )
-
- # update notified to True
- for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
- dataset_auto_disable_log.notified = True
- db.session.commit()
+ dataset_auto_disable_log.notified = True
+ session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
except Exception:
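
The rewrite above swaps the ambient db.session for an explicitly scoped session from core.db.session_factory, so all reads and the final notified update happen inside one context manager with a single commit. The shape of the pattern, condensed (the DatasetAutoDisableLog import path is an assumption):

```python
from sqlalchemy import select

from core.db.session_factory import session_factory
from models.dataset import DatasetAutoDisableLog  # import path is an assumption

with session_factory.create_session() as session:
    logs = session.scalars(
        select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
    ).all()
    for log in logs:
        log.notified = True  # mutations stay attached to this session
    session.commit()  # one commit at the end of the scoped block
```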
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 78806927bc..97aaea3395 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -17,6 +17,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
+from constants.dsl_version import CURRENT_APP_DSL_VERSION
from core.helper import ssrf_proxy
from core.plugin.entities.plugin import PluginDependency
from core.trigger.constants import (
@@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
-CURRENT_DSL_VERSION = "0.6.0"
+CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION
class Import(BaseModel):
diff --git a/api/services/app_service.py b/api/services/app_service.py
index afd98e2975..038c59633a 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index e6f5f80a6d..894cb05687 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -30,7 +30,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.file import helpers as file_helpers
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index e4eb9e7582..a6b02630bf 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -3,6 +3,7 @@ from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
+from constants.dsl_version import CURRENT_APP_DSL_VERSION
from enums.cloud_plan import CloudPlan
from enums.hosted_provider import HostedTrialProvider
from services.billing_service import BillingService
@@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel):
class SystemFeatureModel(BaseModel):
+ app_dsl_version: str = ""
sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = ""
enable_marketplace: bool = False
@@ -225,6 +227,7 @@ class FeatureService:
@classmethod
def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
system_features = SystemFeatureModel()
+ system_features.app_dsl_version = CURRENT_APP_DSL_VERSION
cls._fulfill_system_params_from_env(system_features)
diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py
index 68ef67dec1..8b4983e5f7 100644
--- a/api/services/human_input_delivery_test_service.py
+++ b/api/services/human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py
index aa7456dcd3..8c9a81af87 100644
--- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py
@@ -50,7 +50,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param language: language
:return:
"""
- builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
+ builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
@@ -60,5 +60,5 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
:param template_id: Template ID
:return:
"""
- builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data()
+ builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data()
return builtin_data.get("pipeline_templates", {}).get(template_id)
diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
index 0ffbef8365..9d446f6d4b 100644
--- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -10,6 +10,30 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
+class CustomizedTemplateItemDict(TypedDict):
+ id: str
+ name: str
+ description: str
+ icon: dict[str, Any]
+ position: int
+ chunk_structure: str
+
+
+class CustomizedTemplatesResultDict(TypedDict):
+ pipeline_templates: list[CustomizedTemplateItemDict]
+
+
+class CustomizedTemplateDetailDict(TypedDict):
+ id: str
+ name: str
+ icon_info: dict[str, Any]
+ description: str
+ chunk_structure: str
+ export_data: str
+ graph: dict[str, Any]
+ created_by: str
+
+
class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval recommended app from database
@@ -17,12 +41,10 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
_, current_tenant_id = current_account_with_tenant()
- result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
- return result
+ return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
- result = self.fetch_pipeline_template_detail_from_db(template_id)
- return result
+ return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.CUSTOMIZED
@@ -40,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
).all()
- recommended_pipelines_results = []
+ recommended_pipelines_results: list[CustomizedTemplateItemDict] = []
for pipeline_customized_template in pipeline_customized_templates:
- recommended_pipeline_result = {
+ recommended_pipeline_result: CustomizedTemplateItemDict = {
"id": pipeline_customized_template.id,
"name": pipeline_customized_template.name,
"description": pipeline_customized_template.description,
diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
index 073eed221c..2964537c35 100644
--- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, TypedDict
import yaml
from sqlalchemy import select
@@ -9,18 +9,41 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel
from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType
+class PipelineTemplateItemDict(TypedDict):
+ id: str
+ name: str
+ description: str
+ icon: dict[str, Any]
+ copyright: str
+ privacy_policy: str
+ position: int
+ chunk_structure: str
+
+
+class PipelineTemplatesResultDict(TypedDict):
+ pipeline_templates: list[PipelineTemplateItemDict]
+
+
+class PipelineTemplateDetailDict(TypedDict):
+ id: str
+ name: str
+ icon_info: dict[str, Any]
+ description: str
+ chunk_structure: str
+ export_data: str
+ graph: dict[str, Any]
+
+
class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
Retrieval pipeline template from database
"""
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
- result = self.fetch_pipeline_templates_from_db(language)
- return result
+ return self.fetch_pipeline_templates_from_db(language)
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
- result = self.fetch_pipeline_template_detail_from_db(template_id)
- return result
+ return self.fetch_pipeline_template_detail_from_db(template_id)
def get_type(self) -> str:
return PipelineTemplateType.DATABASE
@@ -39,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
).all()
)
- recommended_pipelines_results = []
+ recommended_pipelines_results: list[PipelineTemplateItemDict] = []
for pipeline_built_in_template in pipeline_built_in_templates:
- recommended_pipeline_result = {
+ recommended_pipeline_result: PipelineTemplateItemDict = {
"id": pipeline_built_in_template.id,
"name": pipeline_built_in_template.name,
"description": pipeline_built_in_template.description,
diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
index d5ef745bec..9565ac46cc 100644
--- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
@@ -17,21 +17,18 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
- result: dict[str, Any] | None
try:
- result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
+ return self.fetch_pipeline_template_detail_from_dify_official(template_id)
except Exception as e:
logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e)
- result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
- return result
+ return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
def get_pipeline_templates(self, language: str) -> dict[str, Any]:
try:
- result = self.fetch_pipeline_templates_from_dify_official(language)
+ return self.fetch_pipeline_templates_from_dify_official(language)
except Exception as e:
logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e)
- result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
- return result
+ return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
def get_type(self) -> str:
return PipelineTemplateType.REMOTE
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 968600d1bc..9db6682e10 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -476,7 +476,7 @@ class RagPipelineService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py
index a91f49e9e6..cf39469be8 100644
--- a/api/services/summary_index_service.py
+++ b/api/services/summary_index_service.py
@@ -349,7 +349,6 @@ class SummaryIndexService:
summary_record_id,
)
summary_record_in_session = DocumentSegmentSummary(
- id=summary_record_id, # Use the same ID if available
dataset_id=dataset.id,
document_id=segment.document_id,
chunk_id=segment.id,
@@ -360,6 +359,7 @@
status=SummaryStatus.COMPLETED,
enabled=True,
)
+ summary_record_in_session.id = summary_record_id  # use the same ID if available
session.add(summary_record_in_session)
logger.info(
"Created new summary record (id=%s) for segment %s after vectorization",
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index c96050ce13..1529c2b98f 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
return TruncationResult(StringSegment(value=fallback_result.value), True)
# Apply final fallback - convert to JSON string and truncate
- json_str = dumps_with_segments(result.value, ensure_ascii=False)
+ json_str = dumps_with_segments(result.value)
if len(json_str) > self._max_size_bytes:
json_str = json_str[: self._max_size_bytes] + "..."
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 5ec00ee336..96f936ff9b 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
variable = segment_to_variable(
segment=segment,
selector=draft_var.get_selector(),
- id=draft_var.id,
+ variable_id=draft_var.id,
name=draft_var.name,
description=draft_var.description,
)
@@ -1067,7 +1067,7 @@ class DraftVariableSaver:
filename = f"{self._generate_filename(name)}.txt"
else:
# For other types, store as JSON
- original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
+ original_content_serialized = dumps_with_segments(value_seg.value)
content_type = "application/json"
filename = f"{self._generate_filename(name)}.json"
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index d01331b588..f97b85dc2b 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
from core.trigger.constants import is_trigger_node_type
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
- normalize_human_input_node_data_for_graph,
+ adapt_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import (
@@ -802,7 +802,7 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
- node_type_enum = NodeType(node_type)
+ node_type_enum: NodeType = node_type
node_mapping = get_node_type_classes_mapping()
# return default block config
@@ -1107,7 +1107,7 @@ class WorkflowService:
raise ValueError("Node type must be human-input.")
node_data = HumanInputNodeData.model_validate(
- normalize_human_input_node_data_for_graph(node_config["data"]),
+ adapt_human_input_node_data_for_graph(node_config["data"]),
from_attributes=True,
)
delivery_method = self._resolve_human_input_delivery_method(
@@ -1248,9 +1248,10 @@ class WorkflowService:
variable_pool=variable_pool,
start_at=time.perf_counter(),
)
+ node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"]))
node = HumanInputNode(
- id=node_config["id"],
- config=node_config,
+ node_id=node_config["id"],
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(run_context),
@@ -1540,7 +1541,7 @@ class WorkflowService:
from graphon.nodes.human_input.entities import HumanInputNodeData
try:
- HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
+ HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data))
except Exception as e:
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py
index f8ae3f4b6e..2a60be7762 100644
--- a/api/tasks/mail_human_input_delivery_task.py
+++ b/api/tasks/mail_human_input_delivery_task.py
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
from extensions.ext_database import db
from extensions.ext_mail import mail
from graphon.runtime import GraphRuntimeState, VariablePool
diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
index b5318aaa2b..2392084c36 100644
--- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
+++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamCompletedEvent
@@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=_GP(),
graph_runtime_state=_GS(vp),
)
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index e3476c292b..aaa6092993 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import NodeRunResult
from graphon.nodes.code.code_node import CodeNode
+from graphon.nodes.code.entities import CodeNodeData
from graphon.nodes.code.limits import CodeNodeLimits
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = CodeNode(
- id=str(uuid.uuid4()),
- config=code_config,
+ node_id=str(uuid.uuid4()),
+ config=CodeNodeData.model_validate(code_config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
code_executor=node_factory._code_executor,
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index aa6cf1e021..b9f7b9575b 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.graph import Graph
-from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
+from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -75,8 +75,8 @@ def init_http_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
@@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
node = HttpRequestNode(
- id=str(uuid.uuid4()),
- config=graph_config["nodes"][1],
+ node_id=str(uuid.uuid4()),
+ config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index fa5d63cfbf..3eead70163 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
+from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
@@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
llm_file_saver = MagicMock(spec=LLMFileSaver)
node = LLMNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 52886855b8..f2eabb86c3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -11,6 +11,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
@@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = ParameterExtractorNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ParameterExtractorNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index 9e3e1a47e3..e2e0723fb8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.template_rendering import TemplateRenderError
@@ -86,8 +87,8 @@ def test_execute_template_transform():
assert graph is not None
node = TemplateTransformNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=TemplateTransformNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
jinja2_template_renderer=_SimpleJinja2Renderer(),
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index f9ec51ee10..a8e9422c1e 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.protocols import ToolFileManagerProtocol
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool.tool_node import ToolNode
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -60,8 +61,8 @@ def init_tool_node(config: dict):
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
- id=str(uuid.uuid4()),
- config=config,
+ node_id=str(uuid.uuid4()),
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
index 14d5740072..6524d6ce61 100644
--- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
from sqlalchemy.orm import Session
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryChannelConfig,
EmailDeliveryConfig,
EmailDeliveryMethod,
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index da4f8847d6..5aed230cd4 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -101,8 +101,8 @@ def _build_graph(
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
- id="start",
- config={"id": "start", "data": start_data.model_dump()},
+ node_id="start",
+ config=start_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
@@ -116,8 +116,8 @@ def _build_graph(
],
)
human_node = HumanInputNode(
- id="human",
- config={"id": "human", "data": human_data.model_dump()},
+ node_id="human",
+ config=human_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
@@ -130,8 +130,8 @@ def _build_graph(
desc=None,
)
end_node = EndNode(
- id="end",
- config={"id": "end", "data": end_data.model_dump()},
+ node_id="end",
+ config=end_data,
graph_init_params=params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
index 2e207ddc67..35e41035df 100644
--- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
+++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py
@@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
file_related_id = related_id
return File(
- id=str(uuid4()), # Generate new UUID for File.id
+ file_id=str(uuid4()), # Generate new UUID for File.id
tenant_id=tenant_id,
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=transfer_method,
related_id=file_related_id,
remote_url=remote_url,
diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
index aaf9a85d60..54b7afc018 100644
--- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
+++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py
@@ -271,7 +271,7 @@ def _create_recipient(
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
- from core.workflow.human_input_compat import DeliveryMethodType
+ from core.workflow.human_input_adapter import DeliveryMethodType
from models.human_input import ConsoleDeliveryPayload
delivery = HumanInputDelivery(
diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py
new file mode 100644
index 0000000000..2bec703f0c
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py
@@ -0,0 +1,650 @@
+"""Testcontainers integration tests for SQL-backed DocumentService paths."""
+
+import datetime
+import json
+from unittest.mock import create_autospec, patch
+from uuid import uuid4
+
+import pytest
+from werkzeug.exceptions import Forbidden, NotFound
+
+from core.rag.index_processor.constant.index_type import IndexStructureType
+from extensions.storage.storage_type import StorageType
+from models import Account
+from models.dataset import Dataset, Document
+from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus
+from models.model import UploadFile
+from services.dataset_service import DocumentService
+from services.errors.account import NoPermissionError
+
+FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0)
+
+
+class DocumentServiceIntegrationFactory:
+ @staticmethod
+ def create_dataset(
+ db_session_with_containers,
+ *,
+ tenant_id: str | None = None,
+ created_by: str | None = None,
+ name: str | None = None,
+ ) -> Dataset:
+ dataset = Dataset(
+ tenant_id=tenant_id or str(uuid4()),
+ name=name or f"dataset-{uuid4()}",
+ data_source_type=DataSourceType.UPLOAD_FILE,
+ created_by=created_by or str(uuid4()),
+ )
+ db_session_with_containers.add(dataset)
+ db_session_with_containers.commit()
+ return dataset
+
+ @staticmethod
+ def create_document(
+ db_session_with_containers,
+ *,
+ dataset: Dataset,
+ name: str = "doc.txt",
+ position: int = 1,
+ tenant_id: str | None = None,
+ indexing_status: str = IndexingStatus.COMPLETED,
+ enabled: bool = True,
+ archived: bool = False,
+ is_paused: bool = False,
+ need_summary: bool = False,
+ doc_form: str = IndexStructureType.PARAGRAPH_INDEX,
+ batch: str | None = None,
+ data_source_type: str = DataSourceType.UPLOAD_FILE,
+ data_source_info: dict | None = None,
+ created_by: str | None = None,
+ ) -> Document:
+ document = Document(
+ tenant_id=tenant_id or dataset.tenant_id,
+ dataset_id=dataset.id,
+ position=position,
+ data_source_type=data_source_type,
+ data_source_info=json.dumps(data_source_info or {}),
+ batch=batch or f"batch-{uuid4()}",
+ name=name,
+ created_from=DocumentCreatedFrom.WEB,
+ created_by=created_by or dataset.created_by,
+ doc_form=doc_form,
+ )
+ document.indexing_status = indexing_status
+ document.enabled = enabled
+ document.archived = archived
+ document.is_paused = is_paused
+ document.need_summary = need_summary
+ if indexing_status == IndexingStatus.COMPLETED:
+ document.completed_at = FIXED_UPLOAD_CREATED_AT
+ db_session_with_containers.add(document)
+ db_session_with_containers.commit()
+ return document
+
+ @staticmethod
+ def create_upload_file(
+ db_session_with_containers,
+ *,
+ tenant_id: str,
+ created_by: str,
+ file_id: str | None = None,
+ name: str = "source.txt",
+ ) -> UploadFile:
+ upload_file = UploadFile(
+ tenant_id=tenant_id,
+ storage_type=StorageType.LOCAL,
+ key=f"uploads/{uuid4()}",
+ name=name,
+ size=128,
+ extension="txt",
+ mime_type="text/plain",
+ created_by_role=CreatorUserRole.ACCOUNT,
+ created_by=created_by,
+ created_at=FIXED_UPLOAD_CREATED_AT,
+ used=False,
+ )
+ if file_id:
+ upload_file.id = file_id
+ db_session_with_containers.add(upload_file)
+ db_session_with_containers.commit()
+ return upload_file
+
+
+@pytest.fixture
+def current_user_mock():
+ with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
+ current_user.id = str(uuid4())
+ current_user.current_tenant_id = str(uuid4())
+ current_user.current_role = None
+ yield current_user
+
+
+def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ assert DocumentService.get_document(dataset.id, None) is None
+
+
+def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset)
+
+ result = DocumentService.get_document(dataset.id, document.id)
+
+ assert result is not None
+ assert result.id == document.id
+
+
+def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ result = DocumentService.get_documents_by_ids(dataset.id, [])
+
+ assert result == []
+
+
+def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt")
+ doc_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ name="b.txt",
+ position=2,
+ )
+
+ result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id])
+
+ assert {document.id for document in result} == {doc_a.id, doc_b.id}
+
+
+def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ assert DocumentService.update_documents_need_summary(dataset.id, []) == 0
+
+
+def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ paragraph_doc = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ need_summary=True,
+ )
+ qa_doc = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ need_summary=True,
+ doc_form=IndexStructureType.QA_INDEX,
+ )
+
+ updated_count = DocumentService.update_documents_need_summary(
+ dataset.id,
+ [paragraph_doc.id, qa_doc.id],
+ need_summary=False,
+ )
+
+ db_session_with_containers.expire_all()
+ refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id)
+ refreshed_qa = db_session_with_containers.get(Document, qa_doc.id)
+ assert updated_count == 1
+ assert refreshed_paragraph is not None
+ assert refreshed_qa is not None
+ assert refreshed_paragraph.need_summary is False
+ assert refreshed_qa.need_summary is True
+
+
+def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url:
+ result = DocumentService.get_document_download_url(document)
+
+ assert result == "signed-url"
+ get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True)
+
+
+def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_type=DataSourceType.WEBSITE_CRAWL,
+ data_source_info={"url": "https://example.com"},
+ )
+
+ with pytest.raises(NotFound, match="invalid source"):
+ DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="invalid source",
+ missing_file_message="missing file",
+ )
+
+
+def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={},
+ )
+
+ with pytest.raises(NotFound, match="missing file"):
+ DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="invalid source",
+ missing_file_message="missing file",
+ )
+
+
+def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": 99},
+ )
+
+ result = DocumentService._get_upload_file_id_for_upload_file_document(
+ document,
+ invalid_source_message="invalid source",
+ missing_file_message="missing file",
+ )
+
+ assert result == "99"
+
+
+def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": "missing-file"},
+ )
+
+ with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}):
+ with pytest.raises(NotFound, match="Uploaded file not found"):
+ DocumentService._get_upload_file_for_upload_file_document(document)
+
+
+def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ result = DocumentService._get_upload_file_for_upload_file_document(document)
+
+ assert result.id == upload_file.id
+
+
+def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ with pytest.raises(NotFound, match="Document not found"):
+ DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[str(uuid4())],
+ tenant_id=dataset.tenant_id,
+ )
+
+
+def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ tenant_id=str(uuid4()),
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ with pytest.raises(Forbidden, match="No permission"):
+ DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[document.id],
+ tenant_id=dataset.tenant_id,
+ )
+
+
+def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": str(uuid4())},
+ )
+
+ with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"):
+ DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[document.id],
+ tenant_id=dataset.tenant_id,
+ )
+
+
+def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="a.txt",
+ )
+ upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="b.txt",
+ )
+ document_a = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file_a.id},
+ )
+ document_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ data_source_info={"upload_file_id": upload_file_b.id},
+ )
+
+ mapping = DocumentService._get_upload_files_by_document_id_for_zip_download(
+ dataset_id=dataset.id,
+ document_ids=[document_a.id, document_b.id],
+ tenant_id=dataset.tenant_id,
+ )
+
+ assert mapping[document_a.id].id == upload_file_a.id
+ assert mapping[document_b.id].id == upload_file_b.id
+
+
+def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(
+ current_user_mock, flask_app_with_containers
+):
+ with flask_app_with_containers.app_context():
+ with pytest.raises(NotFound, match="Dataset not found"):
+ DocumentService.prepare_document_batch_download_zip(
+ dataset_id=str(uuid4()),
+ document_ids=[str(uuid4())],
+ tenant_id=current_user_mock.current_tenant_id,
+ current_user=current_user_mock,
+ )
+
+
+def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(
+ db_session_with_containers,
+ current_user_mock,
+):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(
+ db_session_with_containers,
+ tenant_id=current_user_mock.current_tenant_id,
+ created_by=current_user_mock.id,
+ )
+
+ with patch(
+ "services.dataset_service.DatasetService.check_dataset_permission",
+ side_effect=NoPermissionError("denied"),
+ ):
+ with pytest.raises(Forbidden, match="denied"):
+ DocumentService.prepare_document_batch_download_zip(
+ dataset_id=dataset.id,
+ document_ids=[],
+ tenant_id=current_user_mock.current_tenant_id,
+ current_user=current_user_mock,
+ )
+
+
+def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(
+ db_session_with_containers,
+ current_user_mock,
+):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(
+ db_session_with_containers,
+ tenant_id=current_user_mock.current_tenant_id,
+ created_by=current_user_mock.id,
+ )
+ upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="a.txt",
+ )
+ upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="b.txt",
+ )
+ document_a = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file_a.id},
+ )
+ document_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ data_source_info={"upload_file_id": upload_file_b.id},
+ )
+
+ upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
+ dataset_id=dataset.id,
+ document_ids=[document_b.id, document_a.id],
+ tenant_id=current_user_mock.current_tenant_id,
+ current_user=current_user_mock,
+ )
+
+ assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id]
+ assert download_name.endswith(".zip")
+
+
+def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ enabled_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ enabled=True,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ enabled=False,
+ )
+
+ result = DocumentService.get_document_by_dataset_id(dataset.id)
+
+ assert [document.id for document in result] == [enabled_document.id]
+
+
+def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ available_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ indexing_status=IndexingStatus.COMPLETED,
+ enabled=True,
+ archived=False,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ indexing_status=IndexingStatus.ERROR,
+ )
+
+ result = DocumentService.get_working_documents_by_dataset_id(dataset.id)
+
+ assert [document.id for document in result] == [available_document.id]
+
+
+def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ error_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ indexing_status=IndexingStatus.ERROR,
+ )
+ paused_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ indexing_status=IndexingStatus.PAUSED,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=3,
+ indexing_status=IndexingStatus.COMPLETED,
+ )
+
+ result = DocumentService.get_error_documents_by_dataset_id(dataset.id)
+
+ assert {document.id for document in result} == {error_document.id, paused_document.id}
+
+
+def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ batch = f"batch-{uuid4()}"
+ matching_document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ batch=batch,
+ )
+ DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ tenant_id=str(uuid4()),
+ batch=batch,
+ )
+
+ with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
+ current_user.current_tenant_id = dataset.tenant_id
+ result = DocumentService.get_batch_documents(dataset.id, batch)
+
+ assert [document.id for document in result] == [matching_document.id]
+
+
+def test_get_document_file_detail_returns_upload_file(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+
+ result = DocumentService.get_document_file_detail(upload_file.id)
+
+ assert result is not None
+ assert result.id == upload_file.id
+
+
+def test_delete_document_emits_signal_and_commits(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ upload_file = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ )
+ document = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file.id},
+ )
+
+ with patch("services.dataset_service.document_was_deleted.send") as signal_send:
+ DocumentService.delete_document(document)
+
+ assert db_session_with_containers.get(Document, document.id) is None
+ signal_send.assert_called_once_with(
+ document.id,
+ dataset_id=document.dataset_id,
+ doc_form=document.doc_form,
+ file_id=upload_file.id,
+ )
+
+
+def test_delete_documents_ignores_empty_input(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
+ DocumentService.delete_documents(dataset, [])
+
+ delay.assert_not_called()
+
+
+def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX
+ db_session_with_containers.commit()
+ upload_file_a = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="a.txt",
+ )
+ upload_file_b = DocumentServiceIntegrationFactory.create_upload_file(
+ db_session_with_containers,
+ tenant_id=dataset.tenant_id,
+ created_by=dataset.created_by,
+ name="b.txt",
+ )
+ document_a = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ data_source_info={"upload_file_id": upload_file_a.id},
+ )
+ document_b = DocumentServiceIntegrationFactory.create_document(
+ db_session_with_containers,
+ dataset=dataset,
+ position=2,
+ data_source_info={"upload_file_id": upload_file_b.id},
+ )
+
+ with patch("services.dataset_service.batch_clean_document_task.delay") as delay:
+ DocumentService.delete_documents(dataset, [document_a.id, document_b.id])
+
+ assert db_session_with_containers.get(Document, document_a.id) is None
+ assert db_session_with_containers.get(Document, document_b.id) is None
+ delay.assert_called_once()
+ args = delay.call_args.args
+ assert args[0] == [document_a.id, document_b.id]
+ assert args[1] == dataset.id
+ assert set(args[3]) == {upload_file_a.id, upload_file_b.id}
+
+
+def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+ DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3)
+
+ assert DocumentService.get_documents_position(dataset.id) == 4
+
+
+def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers):
+ dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers)
+
+ assert DocumentService.get_documents_position(dataset.id) == 1
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
index 18c5320d0a..80f9083e81 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
index 21a54e909e..ed75363f3b 100644
--- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py
@@ -8,7 +8,7 @@ import pytest
from sqlalchemy.engine import Engine
from configs import dify_config
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
index 328bdbf055..95a867dbb5 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py
@@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py
index 6ff3b19362..e91c0a0597 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py
@@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
file_list = [
File(
tenant_id="t1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index b19a1740eb..22b80b748e 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_upload_file_id",
filename="test.jpg",
@@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
# Create a File object with REMOTE_URL transfer method
test_file = File(
- id="test_file_id",
- type=FileType.IMAGE,
+ file_id="test_file_id",
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/test.jpg",
filename="test.jpg",
diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py
index 0895fac3a4..d1b09c3a58 100644
--- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py
+++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py
@@ -41,17 +41,22 @@ class TestTenantUserPayload:
class TestGetUser:
"""Test get_user function"""
+ @patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db")
- def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
+ def test_should_return_existing_user_by_id(
+ self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
+ ):
"""Test returning existing user when found by ID"""
# Arrange
mock_user = MagicMock()
mock_user.id = "user123"
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- mock_session.get.return_value = mock_user
+ mock_session.scalar.return_value = mock_user
+ mock_query = MagicMock()
+ mock_select.return_value.where.return_value.limit.return_value = mock_query
# Act
with app.app_context():
@@ -59,13 +64,45 @@ class TestGetUser:
# Assert
assert result == mock_user
- mock_session.get.assert_called_once()
+ mock_session.scalar.assert_called_once()
+ @patch("controllers.inner_api.plugin.wraps.select")
+ @patch("controllers.inner_api.plugin.wraps.EndUser")
+ @patch("controllers.inner_api.plugin.wraps.sessionmaker")
+ @patch("controllers.inner_api.plugin.wraps.db")
+ def test_should_not_resolve_non_anonymous_users_across_tenants(
+ self,
+ mock_db,
+ mock_sessionmaker,
+ mock_enduser_class,
+ mock_select,
+ app: Flask,
+ ):
+ """Test that explicit user IDs remain scoped to the current tenant."""
+ # Arrange
+ mock_session = MagicMock()
+ mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = None
+ mock_new_user = MagicMock()
+ mock_new_user.tenant_id = "tenant-current"
+ mock_enduser_class.return_value = mock_new_user
+
+ # Act
+ with app.app_context():
+ result = get_user("tenant-current", "foreign-user-id")
+
+ # Assert
+ assert result == mock_new_user
+ mock_session.get.assert_not_called()
+ mock_session.scalar.assert_called_once()
+ mock_session.add.assert_called_once_with(mock_new_user)
+
+ @patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db")
def test_should_return_existing_anonymous_user_by_session_id(
- self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask
+ self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
):
"""Test returning existing anonymous user by session_id"""
# Arrange
@@ -73,8 +110,9 @@ class TestGetUser:
mock_user.session_id = "anonymous_session"
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- # non-anonymous path uses session.get(); anonymous uses session.scalar()
- mock_session.get.return_value = mock_user
+ mock_session.scalar.return_value = mock_user
+ mock_query = MagicMock()
+ mock_select.return_value.where.return_value.limit.return_value = mock_query
# Act
with app.app_context():
@@ -83,17 +121,22 @@ class TestGetUser:
# Assert
assert result == mock_user
+ @patch("controllers.inner_api.plugin.wraps.select")
@patch("controllers.inner_api.plugin.wraps.EndUser")
@patch("controllers.inner_api.plugin.wraps.sessionmaker")
@patch("controllers.inner_api.plugin.wraps.db")
- def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask):
+ def test_should_create_new_user_when_not_found(
+ self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask
+ ):
"""Test creating new user when not found in database"""
# Arrange
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- mock_session.get.return_value = None
+ mock_session.scalar.return_value = None
mock_new_user = MagicMock()
mock_enduser_class.return_value = mock_new_user
+ mock_query = MagicMock()
+ mock_select.return_value.where.return_value.limit.return_value = mock_query
# Act
with app.app_context():
@@ -134,7 +177,7 @@ class TestGetUser:
# Arrange
mock_session = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
- mock_session.get.side_effect = Exception("Database error")
+ mock_session.scalar.side_effect = Exception("Database error")
# Act & Assert
with app.app_context():
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
index 14c35a9ed5..4fb8ecf784 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
@@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
ConversationVariableUpdatePayload,
)
from controllers.service_api.app.error import NotChatAppError
+from fields._value_type_serializer import serialize_value_type
+from graphon.variables import StringSegment
from graphon.variables.types import SegmentType
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
assert response.created_at == int(created_at.timestamp())
assert response.updated_at == int(created_at.timestamp())
+ def test_variable_response_normalizes_string_value_type_alias(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SegmentType.INTEGER.value,
+ }
+ )
+
+ assert response.value_type == "number"
+
+ def test_variable_response_normalizes_callable_exposed_type(self):
+ response = ConversationVariableResponse.model_validate(
+ {
+ "id": "550e8400-e29b-41d4-a716-446655440000",
+ "name": "foo",
+ "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
+ }
+ )
+
+ assert response.value_type == "string"
+
+ def test_serialize_value_type_supports_segments_and_mappings(self):
+ assert serialize_value_type(StringSegment(value="hello")) == "string"
+ assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
+
def test_variable_pagination_response(self):
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
{
diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
index 3ab63aed25..dd6cd0e919 100644
--- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
+++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py
@@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
def create_test_file(self, file_id: str = "test_file_1") -> File:
"""Create a test File object"""
return File(
- id=file_id,
- type=FileType.DOCUMENT,
+ file_id=file_id,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related_123",
filename=f"{file_id}.txt",
diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
index a04a7b7576..6104b8d6ca 100644
--- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py
+++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py
@@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow import node_factory as node_factory_module
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.system_variables import build_system_variables
from graphon.entities import WorkflowStartReason
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.entities.pause_reason import SchedulingPause
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
from graphon.graph import Graph
@@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
def version(cls) -> str:
return "1"
- def init_node_data(self, data):
- self._node_data = _StubToolNodeData.model_validate(data)
+ def __init__(
+ self,
+ node_id: str,
+ config: _StubToolNodeData,
+ *,
+ graph_init_params,
+ graph_runtime_state,
+ **_kwargs: Any,
+ ) -> None:
+ super().__init__(
+ node_id=node_id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
def _get_error_strategy(self):
return self._node_data.error_strategy
@@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
def _patch_tool_node(mocker):
- original_create_node = DifyNodeFactory.create_node
+ original_resolve_node_class = node_factory_module.resolve_workflow_node_class
- def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
- node_data = typed_node_config["data"]
- if node_data.type == BuiltinNodeTypes.TOOL:
- return _StubToolNode(
- id=str(typed_node_config["id"]),
- config=typed_node_config,
- graph_init_params=self.graph_init_params,
- graph_runtime_state=self.graph_runtime_state,
- )
- return original_create_node(self, typed_node_config)
+ def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
+ if node_type == BuiltinNodeTypes.TOOL:
+ return _StubToolNode
+ return original_resolve_node_class(node_type=node_type, node_version=node_version)
- mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
+ mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
index cddd03f4b0..701863b927 100644
--- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
+++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
@@ -26,8 +26,8 @@ def _build_file(
extension: str | None = None,
) -> File:
return File(
- id="file-id",
- type=FileType.IMAGE,
+ file_id="file-id",
+ file_type=FileType.IMAGE,
transfer_method=transfer_method,
reference=reference,
remote_url=remote_url,
@@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
assert runtime.multimodal_send_format == "url"
- with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
+ with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
assert runtime.http_get("http://example", follow_redirects=False) == "response"
mock_get.assert_called_once_with("http://example", follow_redirects=False)
diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
index c4bfb23272..30a068f4c5 100644
--- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
@@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
class DummyNode:
- def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
- self.id = id
+ def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
+ self.id = node_id
self.config = config
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
index 81315d2508..deeac49bbc 100644
--- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py
+++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
@@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
built = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tool_file_1",
extension=".png",
@@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
file_in = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="tf",
extension=".pdf",
diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
index a0b2820157..aeca2e3afd 100644
--- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py
+++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py
@@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
# Assert
assert simple_provider.provider == "openai"
- assert simple_provider.label.en_US == "OpenAI"
+ assert simple_provider.label.en_us == "OpenAI"
assert simple_provider.supported_model_types == [ModelType.LLM]
diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
index fe2c226843..a28143026f 100644
--- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
+++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
@@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
)
]
)
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key")
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ encrypted_config="encrypted-old-key"
+ )
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
- with patch(
- "core.entities.provider_configuration.encrypter.encrypt_token",
- side_effect=lambda tenant_id, value: f"enc::{value}",
- ):
- validated = configuration.validate_provider_credentials(
- credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
- credential_id="credential-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
+ with patch(
+ "core.entities.provider_configuration.encrypter.encrypt_token",
+ side_effect=lambda tenant_id, value: f"enc::{value}",
+ ):
+ validated = configuration.validate_provider_credentials(
+ credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"},
+ credential_id="credential-1",
+ )
assert validated["openai_api_key"] == "enc::restored-key"
assert validated["region"] == "us"
@@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
)
-def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
+def test_validate_provider_credentials_without_credential_id() -> None:
configuration = _build_provider_configuration()
- mock_session = Mock()
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"region": "us"}
- with patch("core.entities.provider_configuration.Session") as mock_session_cls:
- with patch("core.entities.provider_configuration.db") as mock_db:
- mock_db.engine = Mock()
- mock_session_cls.return_value.__enter__.return_value = mock_session
- with patch(
- "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
- ):
- validated = configuration.validate_provider_credentials(credentials={"region": "us"})
+ with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
+ validated = configuration.validate_provider_credentials(credentials={"region": "us"})
assert validated == {"region": "us"}
- mock_session_cls.assert_called_once()
def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None:
@@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non
def test_validate_provider_credentials_handles_invalid_original_json() -> None:
configuration = _build_provider_configuration()
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ encrypted_config="{invalid-json"
+ )
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
- validated = configuration.validate_provider_credentials(
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
+ validated = configuration.validate_provider_credentials(
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
assert validated == {"openai_api_key": "enc-key"}
@@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback(
def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None:
configuration = _build_provider_configuration()
configuration.provider.model_credential_schema = _build_secret_model_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
encrypted_config='{"openai_api_key":"enc"}'
)
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
- validated = configuration.validate_custom_model_credentials(
- model_type=ModelType.LLM,
- model="gpt-4o",
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
- assert validated == {"openai_api_key": "enc-new"}
-
- session = Mock()
- mock_factory = Mock()
- mock_factory.model_credentials_validate.return_value = {"region": "us"}
- with _patched_session(session):
+ with _patched_session(mock_session):
with patch(
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
):
- validated = configuration.validate_custom_model_credentials(
- model_type=ModelType.LLM,
- model="gpt-4o",
- credentials={"region": "us"},
- )
+ with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
+ validated = configuration.validate_custom_model_credentials(
+ model_type=ModelType.LLM,
+ model="gpt-4o",
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
+ assert validated == {"openai_api_key": "enc-new"}
+
+ mock_factory2 = Mock()
+ mock_factory2.model_credentials_validate.return_value = {"region": "us"}
+ with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2):
+ validated = configuration.validate_custom_model_credentials(
+ model_type=ModelType.LLM,
+ model="gpt-4o",
+ credentials={"region": "us"},
+ )
assert validated == {"region": "us"}
@@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None:
def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None:
configuration = _build_provider_configuration()
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = None
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = None
mock_factory = Mock()
mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
- validated = configuration.validate_provider_credentials(
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
+ validated = configuration.validate_provider_credentials(
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
assert validated == {"openai_api_key": "enc-new"}
@@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None:
def test_validate_custom_model_credentials_handles_invalid_original_json() -> None:
configuration = _build_provider_configuration()
configuration.provider.model_credential_schema = _build_secret_model_schema()
- session = Mock()
- session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json")
+ mock_session = Mock()
+ mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
+ encrypted_config="{invalid-json"
+ )
mock_factory = Mock()
mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
- with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
- with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
- validated = configuration.validate_custom_model_credentials(
- model_type=ModelType.LLM,
- model="gpt-4o",
- credentials={"openai_api_key": HIDDEN_VALUE},
- credential_id="cred-1",
- session=session,
- )
+ with _patched_session(mock_session):
+ with patch(
+ "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+ ):
+ with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
+ validated = configuration.validate_custom_model_credentials(
+ model_type=ModelType.LLM,
+ model="gpt-4o",
+ credentials={"openai_api_key": HIDDEN_VALUE},
+ credential_id="cred-1",
+ )
assert validated == {"openai_api_key": "enc-new"}
diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py
index bb6e40e224..8cb0938575 100644
--- a/api/tests/unit_tests/core/file/test_models.py
+++ b/api/tests/unit_tests/core/file/test_models.py
@@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
def test_file():
file = File(
- id="test-file",
+ file_id="test-file",
tenant_id="test-tenant-id",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="test-related-id",
filename="image.png",
@@ -25,27 +25,21 @@ def test_file():
assert file.size == 67
-def test_file_model_validate_accepts_legacy_tenant_id():
- data = {
- "id": "test-file",
- "tenant_id": "test-tenant-id",
- "type": "image",
- "transfer_method": "tool_file",
- "related_id": "test-related-id",
- "filename": "image.png",
- "extension": ".png",
- "mime_type": "image/png",
- "size": 67,
- "storage_key": "test-storage-key",
- "url": "https://example.com/image.png",
- # Extra legacy fields
- "tool_file_id": "tool-file-123",
- "upload_file_id": "upload-file-456",
- "datasource_file_id": "datasource-file-789",
- }
+def test_file_constructor_accepts_legacy_tenant_id():
+ file = File(
+ file_id="test-file",
+ tenant_id="test-tenant-id",
+ file_type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ tool_file_id="tool-file-123",
+ filename="image.png",
+ extension=".png",
+ mime_type="image/png",
+ size=67,
+ storage_key="test-storage-key",
+ url="https://example.com/image.png",
+ )
- file = File.model_validate(data)
-
- assert file.related_id == "test-related-id"
+ assert file.related_id == "tool-file-123"
assert file.storage_key == "test-storage-key"
assert "tenant_id" not in file.model_dump()
diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
index 3b5c5e6597..d9fed9ae2a 100644
--- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
+++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py
@@ -1,11 +1,17 @@
from unittest.mock import MagicMock, patch
+import httpx
import pytest
from core.helper.ssrf_proxy import (
SSRF_DEFAULT_MAX_RETRIES,
+ SSRFProxy,
_get_user_provided_host_header,
+ _to_graphon_http_response,
+ graphon_ssrf_proxy,
make_request,
+ max_retries_exceeded_error,
+ request_error,
)
@@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
+
+
+def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
+ response = httpx.Response(
+ 201,
+ headers={"X-Test": "1"},
+ content=b"payload",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ wrapped = _to_graphon_http_response(response)
+
+ assert wrapped.status_code == 201
+ assert wrapped.headers == {"x-test": "1", "content-length": "7"}
+ assert wrapped.content == b"payload"
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.reason_phrase == "Created"
+ assert wrapped.text == "payload"
+
+
+def test_ssrf_proxy_exposes_expected_error_types() -> None:
+ proxy = SSRFProxy()
+
+ assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert proxy.request_error is request_error
+ assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
+ assert graphon_ssrf_proxy.request_error is request_error
+
+
+@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
+def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
+ response = httpx.Response(
+ 200,
+ headers={"X-Test": "1"},
+ content=b"ok",
+ request=httpx.Request("GET", "https://example.com/resource"),
+ )
+
+ with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
+ wrapped = getattr(graphon_ssrf_proxy, method_name)(
+ "https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+
+ mock_method.assert_called_once_with(
+ url="https://example.com/resource",
+ max_retries=3,
+ headers={"X-Test": "1"},
+ )
+ assert wrapped.status_code == 200
+ assert wrapped.url == "https://example.com/resource"
+ assert wrapped.content == b"ok"
diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
index 249ecb5006..c4fd970562 100644
--- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
+++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py
@@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderCredentialSchema,
ProviderEntity,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
-from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
-from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
-from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
+from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
+from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
+from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
+from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py
index 2cbff54c42..69650c85cc 100644
--- a/api/tests/unit_tests/core/ops/test_config_entity.py
+++ b/api/tests/unit_tests/core/ops/test_config_entity.py
@@ -1,16 +1,11 @@
-import pytest
-from pydantic import ValidationError
+from dify_trace_aliyun.config import AliyunConfig
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langsmith.config import LangSmithConfig
+from dify_trace_opik.config import OpikConfig
+from dify_trace_weave.config import WeaveConfig
-from core.ops.entities.config_entity import (
- AliyunConfig,
- ArizeConfig,
- LangfuseConfig,
- LangSmithConfig,
- OpikConfig,
- PhoenixConfig,
- TracingProviderEnum,
- WeaveConfig,
-)
+from core.ops.entities.config_entity import TracingProviderEnum
class TestTracingProviderEnum:
@@ -27,349 +22,8 @@ class TestTracingProviderEnum:
assert TracingProviderEnum.ALIYUN == "aliyun"
-class TestArizeConfig:
- """Test cases for ArizeConfig"""
-
- def test_valid_config(self):
- """Test valid Arize configuration"""
- config = ArizeConfig(
- api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
- )
- assert config.api_key == "test_key"
- assert config.space_id == "test_space"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.arize.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = ArizeConfig()
- assert config.api_key is None
- assert config.space_id is None
- assert config.project is None
- assert config.endpoint == "https://otlp.arize.com"
-
- def test_project_validation_empty(self):
- """Test project validation with empty value"""
- config = ArizeConfig(project="")
- assert config.project == "default"
-
- def test_project_validation_none(self):
- """Test project validation with None value"""
- config = ArizeConfig(project=None)
- assert config.project == "default"
-
- def test_endpoint_validation_empty(self):
- """Test endpoint validation with empty value"""
- config = ArizeConfig(endpoint="")
- assert config.endpoint == "https://otlp.arize.com"
-
- def test_endpoint_validation_with_path(self):
- """Test endpoint validation normalizes URL by removing path"""
- config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
- assert config.endpoint == "https://custom.arize.com"
-
- def test_endpoint_validation_invalid_scheme(self):
- """Test endpoint validation rejects invalid schemes"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- ArizeConfig(endpoint="ftp://invalid.com")
-
- def test_endpoint_validation_no_scheme(self):
- """Test endpoint validation rejects URLs without scheme"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- ArizeConfig(endpoint="invalid.com")
-
-
-class TestPhoenixConfig:
- """Test cases for PhoenixConfig"""
-
- def test_valid_config(self):
- """Test valid Phoenix configuration"""
- config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
- assert config.api_key == "test_key"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.phoenix.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = PhoenixConfig()
- assert config.api_key is None
- assert config.project is None
- assert config.endpoint == "https://app.phoenix.arize.com"
-
- def test_project_validation_empty(self):
- """Test project validation with empty value"""
- config = PhoenixConfig(project="")
- assert config.project == "default"
-
- def test_endpoint_validation_with_path(self):
- """Test endpoint validation with path"""
- config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
- assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
-
- def test_endpoint_validation_without_path(self):
- """Test endpoint validation without path"""
- config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
- assert config.endpoint == "https://app.phoenix.arize.com"
-
-
-class TestLangfuseConfig:
- """Test cases for LangfuseConfig"""
-
- def test_valid_config(self):
- """Test valid Langfuse configuration"""
- config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
- assert config.public_key == "public_key"
- assert config.secret_key == "secret_key"
- assert config.host == "https://custom.langfuse.com"
-
- def test_valid_config_with_path(self):
- host = "https://custom.langfuse.com/api/v1"
- config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host)
- assert config.public_key == "public_key"
- assert config.secret_key == "secret_key"
- assert config.host == host
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = LangfuseConfig(public_key="public", secret_key="secret")
- assert config.host == "https://api.langfuse.com"
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- LangfuseConfig()
-
- with pytest.raises(ValidationError):
- LangfuseConfig(public_key="public")
-
- with pytest.raises(ValidationError):
- LangfuseConfig(secret_key="secret")
-
- def test_host_validation_empty(self):
- """Test host validation with empty value"""
- config = LangfuseConfig(public_key="public", secret_key="secret", host="")
- assert config.host == "https://api.langfuse.com"
-
-
-class TestLangSmithConfig:
- """Test cases for LangSmithConfig"""
-
- def test_valid_config(self):
- """Test valid LangSmith configuration"""
- config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
- assert config.api_key == "test_key"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.smith.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = LangSmithConfig(api_key="key", project="project")
- assert config.endpoint == "https://api.smith.langchain.com"
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- LangSmithConfig()
-
- with pytest.raises(ValidationError):
- LangSmithConfig(api_key="key")
-
- with pytest.raises(ValidationError):
- LangSmithConfig(project="project")
-
- def test_endpoint_validation_https_only(self):
- """Test endpoint validation only allows HTTPS"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
-
-
-class TestOpikConfig:
- """Test cases for OpikConfig"""
-
- def test_valid_config(self):
- """Test valid Opik configuration"""
- config = OpikConfig(
- api_key="test_key",
- project="test_project",
- workspace="test_workspace",
- url="https://custom.comet.com/opik/api/",
- )
- assert config.api_key == "test_key"
- assert config.project == "test_project"
- assert config.workspace == "test_workspace"
- assert config.url == "https://custom.comet.com/opik/api/"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = OpikConfig()
- assert config.api_key is None
- assert config.project is None
- assert config.workspace is None
- assert config.url == "https://www.comet.com/opik/api/"
-
- def test_project_validation_empty(self):
- """Test project validation with empty value"""
- config = OpikConfig(project="")
- assert config.project == "Default Project"
-
- def test_url_validation_empty(self):
- """Test URL validation with empty value"""
- config = OpikConfig(url="")
- assert config.url == "https://www.comet.com/opik/api/"
-
- def test_url_validation_missing_suffix(self):
- """Test URL validation requires /api/ suffix"""
- with pytest.raises(ValidationError, match="URL should end with /api/"):
- OpikConfig(url="https://custom.comet.com/opik/")
-
- def test_url_validation_invalid_scheme(self):
- """Test URL validation rejects invalid schemes"""
- with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
- OpikConfig(url="ftp://custom.comet.com/opik/api/")
-
-
-class TestWeaveConfig:
- """Test cases for WeaveConfig"""
-
- def test_valid_config(self):
- """Test valid Weave configuration"""
- config = WeaveConfig(
- api_key="test_key",
- entity="test_entity",
- project="test_project",
- endpoint="https://custom.wandb.ai",
- host="https://custom.host.com",
- )
- assert config.api_key == "test_key"
- assert config.entity == "test_entity"
- assert config.project == "test_project"
- assert config.endpoint == "https://custom.wandb.ai"
- assert config.host == "https://custom.host.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = WeaveConfig(api_key="key", project="project")
- assert config.entity is None
- assert config.endpoint == "https://trace.wandb.ai"
- assert config.host is None
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- WeaveConfig()
-
- with pytest.raises(ValidationError):
- WeaveConfig(api_key="key")
-
- with pytest.raises(ValidationError):
- WeaveConfig(project="project")
-
- def test_endpoint_validation_https_only(self):
- """Test endpoint validation only allows HTTPS"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
-
- def test_host_validation_optional(self):
- """Test host validation is optional but validates when provided"""
- config = WeaveConfig(api_key="key", project="project", host=None)
- assert config.host is None
-
- config = WeaveConfig(api_key="key", project="project", host="")
- assert config.host == ""
-
- config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
- assert config.host == "https://valid.host.com"
-
- def test_host_validation_invalid_scheme(self):
- """Test host validation rejects invalid schemes when provided"""
- with pytest.raises(ValidationError, match="URL scheme must be one of"):
- WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
-
-
-class TestAliyunConfig:
- """Test cases for AliyunConfig"""
-
- def test_valid_config(self):
- """Test valid Aliyun configuration"""
- config = AliyunConfig(
- app_name="test_app",
- license_key="test_license_key",
- endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
- )
- assert config.app_name == "test_app"
- assert config.license_key == "test_license_key"
- assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
-
- def test_default_values(self):
- """Test default values are set correctly"""
- config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
- assert config.app_name == "dify_app"
-
- def test_missing_required_fields(self):
- """Test that required fields are enforced"""
- with pytest.raises(ValidationError):
- AliyunConfig()
-
- with pytest.raises(ValidationError):
- AliyunConfig(license_key="test_license")
-
- with pytest.raises(ValidationError):
- AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_app_name_validation_empty(self):
- """Test app_name validation with empty value"""
- config = AliyunConfig(
- license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
- )
- assert config.app_name == "dify_app"
-
- def test_endpoint_validation_empty(self):
- """Test endpoint validation with empty value"""
- config = AliyunConfig(license_key="test_license", endpoint="")
- assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
-
- def test_endpoint_validation_with_path(self):
- """Test endpoint validation preserves path for Aliyun endpoints"""
- config = AliyunConfig(
- license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
- )
- assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
-
- def test_endpoint_validation_invalid_scheme(self):
- """Test endpoint validation rejects invalid schemes"""
- with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
- AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_endpoint_validation_no_scheme(self):
- """Test endpoint validation rejects URLs without scheme"""
- with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
- AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_license_key_required(self):
- """Test that license_key is required and cannot be empty"""
- with pytest.raises(ValidationError):
- AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
-
- def test_valid_endpoint_format_examples(self):
- """Test valid endpoint format examples from comments"""
- valid_endpoints = [
- # cms2.0 public endpoint
- "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
- # cms2.0 intranet endpoint
- "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
- # xtrace public endpoint
- "http://tracing-cn-heyuan.arms.aliyuncs.com",
- # xtrace intranet endpoint
- "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
- ]
-
- for endpoint in valid_endpoints:
- config = AliyunConfig(license_key="test_license", endpoint=endpoint)
- assert config.endpoint == endpoint
-
-
class TestConfigIntegration:
- """Integration tests for configuration classes"""
+ """Cross-provider configuration sanity checks"""
def test_all_configs_can_be_instantiated(self):
"""Test that all config classes can be instantiated with valid data"""
@@ -388,7 +42,6 @@ class TestConfigIntegration:
def test_url_normalization_consistency(self):
"""Test that URL normalization works consistently across configs"""
- # Test that paths are removed from endpoints
arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
index 68aa130518..88bf555594 100644
--- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
+++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py
@@ -56,7 +56,7 @@ class TestPluginModelRuntime:
assert len(providers) == 1
assert providers[0].provider == "langgenius/openai/openai"
assert providers[0].provider_name == "openai"
- assert providers[0].label.en_US == "OpenAI"
+ assert providers[0].label.en_us == "OpenAI"
client.fetch_model_providers.assert_called_once_with("tenant")
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:
diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
index d49b6e4b71..00a4207786 100644
--- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
+++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py
@@ -466,7 +466,7 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
file_param = File(
tenant_id="tenant-1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/file.png",
storage_key="",
@@ -499,14 +499,14 @@ class TestConverter:
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
file_one = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/a.txt",
storage_key="",
)
file_two = File(
tenant_id="tenant-1",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/b.txt",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
index 395d392127..e536c0831f 100644
--- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py
@@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
files = [
File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
@@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
memory = MagicMock(spec=TokenBufferMemory)
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
@@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
transform = AdvancedPromptTransform()
model_config_mock = MagicMock(spec=ModelConfigEntity)
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
storage_key="",
diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
index 803afa54d7..28966242d8 100644
--- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py
@@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from models.model import Conversation
diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
index 9f9ea33695..5308c8e7b3 100644
--- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py
@@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
-# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
# from core.prompt.prompt_transform import PromptTransform
diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
index 64eb89590a..0220fb6d4a 100644
--- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
+++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py
@@ -1,12 +1,14 @@
"""Primarily used for testing merged cell scenarios"""
+import gc
import io
import os
import tempfile
+import warnings
from collections import UserDict
from pathlib import Path
from types import SimpleNamespace
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
import pytest
from docx import Document
@@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
WordExtractor("not-a-file", "tenant", "user")
-def test_del_closes_temp_file():
+def test_close_closes_temp_file():
extractor = object.__new__(WordExtractor)
+ extractor._closed = False
extractor.temp_file = MagicMock()
- WordExtractor.__del__(extractor)
+ extractor.close()
extractor.temp_file.close.assert_called_once()
+def test_close_is_idempotent():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+
+ extractor.close()
+ extractor.close()
+
+ extractor.temp_file.close.assert_called_once()
+
+
+def test_close_handles_async_close_mock():
+ extractor = object.__new__(WordExtractor)
+ extractor._closed = False
+ extractor.temp_file = MagicMock()
+ extractor.temp_file.close = AsyncMock()
+
+ with warnings.catch_warnings(record=True) as caught:
+ warnings.simplefilter("always")
+ extractor.close()
+ gc.collect()
+
+ extractor.temp_file.close.assert_called_once()
+ assert not [
+ warning
+ for warning in caught
+ if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
+ ]
+
+
def test_extract_images_handles_invalid_external_cases(monkeypatch):
class FakeTargetRef:
def __contains__(self, item):
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
index 8be1ac318c..18ae9fafc8 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
index 1297a95df1..4248782d93 100644
--- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py
+++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py
@@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
_InvalidTimeoutStatusError,
_WorkspaceMemberInfo,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py
index f17927f16b..eab0176f41 100644
--- a/api/tests/unit_tests/core/test_file.py
+++ b/api/tests/unit_tests/core/test_file.py
@@ -6,9 +6,9 @@ from models.workflow import Workflow
def test_file_to_dict():
file = File(
- id="file1",
+ file_id="file1",
tenant_id="tenant1",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 72052c8c05..9e07ea1b6d 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -1,8 +1,9 @@
import dataclasses
+from typing import Annotated
import orjson
import pytest
-from pydantic import BaseModel
+from pydantic import BaseModel, Discriminator, Tag
from core.helper import encrypter
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
@@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import (
ArrayAnySegment,
+ ArrayBooleanSegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
+ BooleanSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
- SegmentUnion,
StringSegment,
get_segment_discriminator,
)
@@ -47,6 +49,26 @@ from graphon.variables.variables import (
StringVariable,
Variable,
)
+from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
+
+type SegmentUnion = Annotated[
+ (
+ Annotated[NoneSegment, Tag(SegmentType.NONE)]
+ | Annotated[StringSegment, Tag(SegmentType.STRING)]
+ | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
+ | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
+ | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
+ | Annotated[FileSegment, Tag(SegmentType.FILE)]
+ | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
+ | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
+ | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
+ | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
+ | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
+ | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
+ | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
+ ),
+ Discriminator(get_segment_discriminator),
+]
def _build_variable_pool(
@@ -123,7 +145,7 @@ def create_test_file(
) -> File:
"""Factory function to create File objects for testing"""
return File(
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
@@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
assert restored == model
def test_all_segments_serialization(self):
- """Test serialization/deserialization of all segment types"""
+ """Test file-aware segment serialization through Dify's model boundary."""
# Create one instance of each segment type
test_file = create_test_file()
@@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
- loaded = _Segments.model_validate_json(json_str)
+ loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
@@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
- """Test serialization/deserialization of all variable types"""
+ """Test file-aware variable serialization through Dify's model boundary."""
# Create one instance of each variable type
test_file = create_test_file()
@@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
- loaded = _Variables.model_validate_json(json_str)
+ loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)
diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
index 94e788edb2..317fe99d37 100644
--- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py
+++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py
@@ -35,7 +35,7 @@ def create_test_file(
"""Factory function to create File objects for testing."""
return File(
tenant_id="test-tenant",
- type=file_type,
+ file_type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
index 76b2984a4b..9f3e3b00b9 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
@@ -1,12 +1,13 @@
-"""
-Mock node factory for testing workflows with third-party service dependencies.
+"""Mock node factory for third-party-service workflow tests.
-This module provides a MockNodeFactory that automatically detects and mocks nodes
-requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
+The factory follows the same config adaptation path as production
+`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
+implementations before instantiation.
"""
from typing import TYPE_CHECKING, Any
+from core.workflow.human_input_adapter import adapt_node_config_for_graph
from core.workflow.node_factory import DifyNodeFactory
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes, NodeType
@@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
- typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
+ typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
+ node_id = typed_node_config["id"]
node_data = typed_node_config["data"]
node_type = node_data.type
# Check if this node type should be mocked
if node_type in self._mock_node_types:
- node_id = typed_node_config["id"]
-
# Create mock node instance
mock_class = self._mock_node_types[node_type]
+ resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
if node_type == BuiltinNodeTypes.CODE:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
)
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
}:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
)
else:
mock_instance = mock_class(
- id=node_id,
- config=typed_node_config,
+ node_id=node_id,
+ config=resolved_node_data,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
@@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
- return super().create_node(typed_node_config)
+ return super().create_node(node_config)
def should_mock_node(self, node_type: NodeType) -> bool:
"""
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index 971b9b2bbf..f9819c47ec 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -55,13 +55,14 @@ class MockNodeMixin:
def __init__(
self,
- id: str,
- config: Mapping[str, Any],
+ node_id: str,
+ config: Any,
+ *,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
- ):
+ ) -> None:
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
@@ -96,7 +97,7 @@ class MockNodeMixin:
kwargs.setdefault("message_transformer", MagicMock())
super().__init__(
- id=id,
+ node_id=node_id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
index 55a329eba9..75bc6d05f7 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py
@@ -139,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
- id=start_config["id"],
- config=start_config,
+ node_id=start_config["id"],
+ config=StartNodeData(title="Start", variables=[]),
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -154,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
- id=human_a_config["id"],
- config=human_a_config,
+ node_id=human_a_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -164,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = HumanInputNode(
- id=human_b_config["id"],
- config=human_b_config,
+ node_id=human_b_config["id"],
+ config=human_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
@@ -182,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor
)
end_config = {"id": "end", "data": end_data.model_dump()}
end_node = EndNode(
- id=end_config["id"],
- config=end_config,
+ node_id=end_config["id"],
+ config=end_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index 9c0ad25b58..76b4cd1ef4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -9,6 +9,7 @@ from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.graph import Graph
from graphon.nodes.answer.answer_node import AnswerNode
+from graphon.nodes.answer.entities import AnswerNodeData
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,20 +67,15 @@ def test_execute_answer():
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- }
-
node = AnswerNode(
- id=str(uuid.uuid4()),
+ node_id=str(uuid.uuid4()),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
+ config=AnswerNodeData(
+ title="123",
+ type="answer",
+ answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ ),
)
# Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
index 9cceadde49..d7ef781732 100644
--- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
@@ -1,5 +1,6 @@
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+from core.workflow.nodes.datasource.entities import DatasourceNodeData
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -77,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker):
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
node = DatasourceNode(
- id="n",
- config={
- "id": "n",
- "data": {
- "type": "datasource",
- "version": "1",
- "title": "Datasource",
- "provider_type": "plugin",
- "provider_name": "p",
- "plugin_id": "plug",
- "datasource_name": "ds",
- },
- },
+ node_id="n",
+ config=DatasourceNodeData(
+ type="datasource",
+ version="1",
+ title="Datasource",
+ provider_type="plugin",
+ provider_name="p",
+ plugin_id="plug",
+ datasource_name="ds",
+ ),
graph_init_params=gp,
graph_runtime_state=gs,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index a3cadc0681..2e89a2da3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -12,7 +12,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.file.file_manager import file_manager
from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig
-from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response
+from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
@@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config():
assert default_config["retry_config"]["max_retries"] == 7
-def test_get_default_config_with_malformed_http_request_config_raises_value_error():
- with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"):
+def test_get_default_config_with_malformed_http_request_config_raises_type_error():
+ with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"):
HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"})
@@ -114,8 +114,8 @@ def _build_http_node(
start_at=time.perf_counter(),
)
return HttpRequestNode(
- id="http-node",
- config=node_config,
+ node_id="http-node",
+ config=HttpRequestNodeData.model_validate(node_data),
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
http_request_config=HTTP_REQUEST_CONFIG,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
index 1d6a4da7c4..07430498e5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py
@@ -1,4 +1,4 @@
-from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients
+from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients
from graphon.runtime import VariablePool
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
index c0e21d0bf7..0659984c76 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py
@@ -19,7 +19,7 @@ from core.repositories.human_input_repository import (
HumanInputFormRecipientEntity,
HumanInputFormRepository,
)
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
DeliveryMethodType,
EmailDeliveryConfig,
EmailDeliveryMethod,
@@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository):
entity.status_value = HumanInputFormStatus.SUBMITTED
+def _build_human_input_node(
+ *,
+ node_id: str,
+ node_data: HumanInputNodeData | Mapping[str, Any],
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ runtime: DifyHumanInputNodeRuntime,
+) -> HumanInputNode:
+ typed_node_data = (
+ node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data)
+ )
+ return HumanInputNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ runtime=runtime,
+ )
+
+
class TestDeliveryMethod:
"""Test DeliveryMethod entity."""
@@ -239,7 +259,7 @@ class TestUserAction:
data[field_name] = value
with pytest.raises(ValidationError) as exc_info:
- UserAction(**data)
+ UserAction.model_validate(data)
errors = exc_info.value.errors()
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
@@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution:
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
@@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent:
form_repository = InMemoryHumanInputFormRepository()
runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context)
runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined]
- node = HumanInputNode(
- id=config["id"],
- config=config,
+ node = _build_human_input_node(
+ node_id=config["id"],
+ node_data=config["data"],
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
runtime=runtime,
diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
index bc98028d5b..4a9438b14f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
+++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py
@@ -11,6 +11,7 @@ from graphon.graph_events import (
NodeRunHumanInputFormTimeoutEvent,
NodeRunStartedEvent,
)
+from graphon.nodes.human_input.entities import HumanInputNodeData
from graphon.nodes.human_input.enums import HumanInputFormStatus
from graphon.nodes.human_input.human_input_node import HumanInputNode
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -25,6 +26,28 @@ class _FakeFormRepository:
return self._form
+def _create_human_input_node(
+ *,
+ config: dict,
+ graph_init_params: GraphInitParams,
+ graph_runtime_state: GraphRuntimeState,
+ repo: _FakeFormRepository,
+) -> HumanInputNode:
+ node_data = (
+ config["data"]
+ if isinstance(config["data"], HumanInputNodeData)
+ else HumanInputNodeData.model_validate(config["data"])
+ )
+ return HumanInputNode(
+ node_id=config["id"],
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ form_repository=repo,
+ runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ )
+
+
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
system_variables = default_system_variables()
graph_runtime_state = GraphRuntimeState(
@@ -80,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
@@ -145,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode:
)
repo = _FakeFormRepository(fake_form)
- return HumanInputNode(
- id="node-1",
+ return _create_human_input_node(
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
- form_repository=repo,
- runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
+ repo=repo,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
index 82cc734274..8ffce39cd6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py
@@ -5,6 +5,7 @@ import pytest
from core.workflow.system_variables import default_system_variables
from graphon.entities import GraphInitParams
+from graphon.nodes.iteration.entities import IterationNodeData
from graphon.nodes.iteration.exc import IterationGraphNotFoundError
from graphon.nodes.iteration.iteration_node import IterationNode
from graphon.runtime import (
@@ -44,17 +45,14 @@ def _build_iteration_node(
) -> IterationNode:
init_params = build_test_graph_init_params(graph_config=graph_config)
return IterationNode(
- id="iteration-node",
- config={
- "id": "iteration-node",
- "data": {
- "type": "iteration",
- "title": "Iteration",
- "iterator_selector": ["start", "items"],
- "output_selector": ["iteration-node", "output"],
- "start_node_id": start_node_id,
- },
- },
+ node_id="iteration-node",
+ config=IterationNodeData(
+ type="iteration",
+ title="Iteration",
+ iterator_selector=["start", "items"],
+ output_selector=["iteration-node", "output"],
+ start_node_id=start_node_id,
+ ),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
index a6fca1bfb4..f254fc3d09 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py
@@ -93,6 +93,25 @@ def sample_chunks():
}
+def _build_node(
+ *,
+ node_id: str,
+ node_data: KnowledgeIndexNodeData | dict[str, object],
+ graph_init_params,
+ graph_runtime_state,
+) -> KnowledgeIndexNode:
+ return KnowledgeIndexNode(
+ node_id=node_id,
+ config=(
+ node_data
+ if isinstance(node_data, KnowledgeIndexNodeData)
+ else KnowledgeIndexNodeData.model_validate(node_data)
+ ),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+
class TestKnowledgeIndexNode:
"""
Test suite for KnowledgeIndexNode.
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode:
}
# Act
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -143,9 +162,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -176,9 +195,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -212,9 +231,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -269,9 +288,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -332,9 +351,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -383,9 +402,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -440,9 +459,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -498,9 +517,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -536,9 +555,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -583,9 +602,9 @@ class TestKnowledgeIndexNode:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex:
"data": sample_node_data.model_dump(),
}
- node = KnowledgeIndexNode(
- id=node_id,
- config=config,
+ node = _build_node(
+ node_id=node_id,
+ node_data=config["data"],
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
index 45e8ae7d20..e923ee761b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py
@@ -14,7 +14,11 @@ from core.workflow.nodes.knowledge_retrieval.entities import (
SingleRetrievalConfig,
)
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
-from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import (
+ KnowledgeRetrievalNode,
+ _normalize_metadata_filter_scalar,
+ _normalize_metadata_filter_sequence_item,
+)
from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowNodeExecutionStatus
@@ -85,6 +89,12 @@ def sample_node_data():
)
+def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None:
+ assert _normalize_metadata_filter_scalar(3) == 3
+ assert _normalize_metadata_filter_scalar(True) == "True"
+ assert _normalize_metadata_filter_sequence_item(4) == "4"
+
+
class TestKnowledgeRetrievalNode:
"""
Test suite for KnowledgeRetrievalNode.
@@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode:
# Act
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode:
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -470,8 +480,8 @@ class TestFetchDatasetRetriever:
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -507,8 +517,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -562,8 +572,8 @@ class TestFetchDatasetRetriever:
}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -610,8 +620,8 @@ class TestFetchDatasetRetriever:
mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -671,8 +681,8 @@ class TestFetchDatasetRetriever:
node_id = str(uuid.uuid4())
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
- id=node_id,
- config=config,
+ node_id=node_id,
+ config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
index eca34f05be..388654f279 100644
--- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -1,3 +1,4 @@
+from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@@ -5,6 +6,7 @@ import pytest
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
from graphon.entities import GraphInitParams
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.nodes.list_operator.entities import ListOperatorNodeData
from graphon.nodes.list_operator.node import ListOperatorNode
from graphon.runtime import GraphRuntimeState
from graphon.variables import ArrayNumberSegment, ArrayStringSegment
@@ -13,11 +15,28 @@ from graphon.variables import ArrayNumberSegment, ArrayStringSegment
class TestListOperatorNode:
"""Comprehensive tests for ListOperatorNode."""
+ @staticmethod
+ def _build_node(*, config, graph_init_params, graph_runtime_state):
+ return ListOperatorNode(
+ node_id="test",
+ config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ @staticmethod
+ def _filter_by(comparison_operator: str, value: str) -> dict[str, object]:
+ return {
+ "enabled": True,
+ "conditions": [{"comparison_operator": comparison_operator, "value": value}],
+ }
+
@pytest.fixture
def mock_graph_runtime_state(self):
"""Create mock GraphRuntimeState."""
mock_state = MagicMock(spec=GraphRuntimeState)
mock_variable_pool = MagicMock()
+ mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
mock_state.variable_pool = mock_variable_pool
return mock_state
@@ -45,9 +64,8 @@ class TestListOperatorNode:
def _create_node(config, mock_variable):
mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
- return ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ return self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -64,9 +82,8 @@ class TestListOperatorNode:
"limit": {"enabled": False},
}
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -109,9 +126,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=[])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -128,11 +144,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -140,9 +152,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -157,11 +168,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "not contains",
- "value": "app",
- },
+ "filter_by": self._filter_by("not contains", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -169,9 +176,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -186,11 +192,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "5",
- },
+ "filter_by": self._filter_by(">", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -198,9 +200,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -226,9 +227,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -254,9 +254,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -282,9 +281,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -299,11 +297,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": ">",
- "value": "3",
- },
+ "filter_by": self._filter_by(">", "3"),
"order_by": {
"enabled": True,
"value": "desc",
@@ -317,9 +311,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -341,9 +334,8 @@ class TestListOperatorNode:
mock_graph_runtime_state.variable_pool.get.return_value = None
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -366,9 +358,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["first", "middle", "last"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -384,11 +375,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "start with",
- "value": "app",
- },
+ "filter_by": self._filter_by("start with", "app"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -396,9 +383,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -413,11 +399,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "items"],
- "filter_by": {
- "enabled": True,
- "condition": "end with",
- "value": "le",
- },
+ "filter_by": self._filter_by("end with", "le"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -425,9 +407,8 @@ class TestListOperatorNode:
mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -442,11 +423,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "=",
- "value": "5",
- },
+ "filter_by": self._filter_by("=", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -454,9 +431,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -471,11 +447,7 @@ class TestListOperatorNode:
config = {
"title": "Test",
"variable": ["sys", "numbers"],
- "filter_by": {
- "enabled": True,
- "condition": "≠",
- "value": "5",
- },
+ "filter_by": self._filter_by("≠", "5"),
"order_by": {"enabled": False},
"limit": {"enabled": False},
}
@@ -483,9 +455,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
@@ -511,9 +482,8 @@ class TestListOperatorNode:
mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
mock_graph_runtime_state.variable_pool.get.return_value = mock_var
- node = ListOperatorNode(
- id="test",
- config={"id": "test", "data": config},
+ node = self._build_node(
+ config=config,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index 4186bbdc93..212ad07bd3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -71,8 +71,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -95,6 +95,8 @@ def variable_pool() -> VariablePool:
def _fetch_prompt_messages_with_mocked_content(content):
variable_pool = VariablePool.empty()
model_instance = mock.MagicMock(spec=ModelInstance)
+ model_schema = mock.MagicMock()
+ model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text"
prompt_template = [
LLMNodeChatModelMessage(
text="You are a classifier.",
@@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content):
with (
mock.patch(
"graphon.nodes.llm.llm_utils.fetch_model_schema",
- return_value=mock.MagicMock(features=[]),
+ return_value=model_schema,
),
mock.patch(
"graphon.nodes.llm.llm_utils.handle_list_messages",
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index b1f81b6c48..c707cf28cd 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -140,8 +140,8 @@ def _build_image_file(
mime_type: str = "image/png",
) -> File:
return File(
- id=file_id,
- type=FileType.IMAGE,
+ file_id=file_id,
+ file_type=FileType.IMAGE,
filename=f"{file_id}{extension}",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=remote_url,
@@ -205,14 +205,10 @@ def llm_node(
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers():
def test_fetch_files_with_file_segment():
file = File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
@@ -420,16 +416,16 @@ def test_fetch_files_with_file_segment():
def test_fetch_files_with_array_file_segment():
files = [
File(
- id="1",
- type=FileType.IMAGE,
+ file_id="1",
+ file_type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
- id="2",
- type=FileType.IMAGE,
+ file_id="2",
+ file_type=FileType.IMAGE,
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
@@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
- node_config = {
- "id": "1",
- "data": llm_node_data.model_dump(),
- }
http_client = mock.MagicMock()
node = LLMNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=llm_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
@@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/png",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
mime_type="image/jpg",
)
mock_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
@@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
image_b64_data = base64.b64encode(image_raw_data).decode()
mock_saved_file = File(
- id=str(uuid.uuid4()),
- type=FileType.IMAGE,
+ file_id=str(uuid.uuid4()),
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
@@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable
file_saver=file_saver,
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
reasoning_format="separated",
)
)
@@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=_build_prepared_llm_mock(),
reasoning_format="separated",
request_start_time=1.0,
@@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors():
file_saver=mock.MagicMock(spec=LLMFileSaver),
file_outputs=[],
node_id="node-1",
- node_type=LLMNode.node_type,
model_instance=model_instance,
)
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index bc44ececd8..892f6cc586 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -13,6 +13,28 @@ from graphon.template_rendering import TemplateRenderError
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_template_transform_node(
+ *,
+ node_data,
+ graph_init_params,
+ graph_runtime_state,
+ node_id: str = "test_node",
+ **kwargs,
+) -> TemplateTransformNode:
+ typed_node_data = (
+ node_data
+ if isinstance(node_data, TemplateTransformNodeData)
+ else TemplateTransformNodeData.model_validate(node_data)
+ )
+ return TemplateTransformNode(
+ node_id=node_id,
+ config=typed_node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ **kwargs,
+ )
+
+
class TestTemplateTransformNode:
"""Comprehensive test suite for TemplateTransformNode."""
@@ -59,9 +81,8 @@ class TestTemplateTransformNode:
def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test that TemplateTransformNode initializes correctly."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -75,9 +96,8 @@ class TestTemplateTransformNode:
def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_title method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -88,9 +108,8 @@ class TestTemplateTransformNode:
def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
"""Test _get_description method."""
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -108,9 +127,8 @@ class TestTemplateTransformNode:
}
mock_renderer = MagicMock()
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -143,9 +161,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
with pytest.raises(ValueError, match="max_output_length must be a positive integer"):
- TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -170,9 +187,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -198,9 +214,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Value: "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -218,9 +233,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error")
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -238,9 +252,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -260,9 +273,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "1234567890"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": basic_node_data},
+ node = _build_template_transform_node(
+ node_data=basic_node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -302,9 +314,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -375,8 +386,8 @@ class TestTemplateTransformNode:
)
assert mapping == {
- "node_123.var1": ["sys", "input1"],
- "node_123.empty_selector": [],
+ "node_123.var1": ("sys", "input1"),
+ "node_123.empty_selector": (),
}
def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self):
@@ -409,9 +420,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "This is a static message."
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -448,9 +458,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Total: $31.5"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -477,9 +486,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com"
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
@@ -507,9 +515,8 @@ class TestTemplateTransformNode:
mock_renderer = MagicMock()
mock_renderer.render_template.return_value = "Tags: #python #ai #workflow "
- node = TemplateTransformNode(
- id="test_node",
- config={"id": "test_node", "data": node_data},
+ node = _build_template_transform_node(
+ node_data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=mock_renderer,
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
index 636237e56e..a846efbb43 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
@@ -4,6 +4,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from graphon.nodes.base.entities import VariableSelector
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
from graphon.nodes.template_transform.template_transform_node import (
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH,
TemplateTransformNode,
@@ -37,15 +38,13 @@ def mock_graph_runtime_state():
def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
node = TemplateTransformNode(
- id="test_node",
- config={
- "id": "test_node",
- "data": {
- "title": "Template Transform",
- "variables": [],
- "template": "hello",
- },
- },
+ node_id="test_node",
+ config=TemplateTransformNodeData(
+ title="Template Transform",
+ type="template-transform",
+ variables=[],
+ template="hello",
+ ),
graph_init_params=graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
jinja2_template_renderer=MagicMock(),
@@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri
assert mapping == {
"node_123.validated": ["sys", "input1"],
- "node_123.raw": ["sys", "input2"],
+ "node_123.raw": ("sys", "input2"),
}
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
index 0522dd9d14..364408ead6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
@@ -7,7 +7,6 @@ from core.workflow.node_runtime import resolve_dify_run_context
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
from graphon.entities.base_node_data import BaseNodeData
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import BuiltinNodeTypes
from graphon.nodes.base.node import Node
from graphon.runtime import GraphRuntimeState, VariablePool
@@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- },
- }
- )
+def _build_node_config() -> dict[str, object]:
+ return {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ ),
+ }
+
+
+def _build_node_data() -> _SampleNodeData:
+ return _build_node_config()["data"] # type: ignore[return-value]
def test_node_hydrates_data_during_initialization():
@@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization():
init_params, runtime_state = _build_context(graph_config)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum():
)
node = _SampleNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error():
def test_base_node_data_keeps_dict_style_access_compatibility():
- node_data = _SampleNodeData.model_validate(
- {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- }
- )
+ node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar")
assert node_data["foo"] == "bar"
assert node_data.get("foo") == "bar"
@@ -133,21 +128,19 @@ def test_base_node_data_keeps_dict_style_access_compatibility():
def test_node_hydration_preserves_compatibility_extra_fields():
graph_config: dict[str, object] = {}
init_params, runtime_state = _build_context(graph_config)
- node_config = NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": BuiltinNodeTypes.ANSWER,
- "title": "Sample",
- "foo": "bar",
- "compat_flag": True,
- },
- }
- )
+ node_config = {
+ "id": "node-1",
+ "data": _SampleNodeData(
+ type=BuiltinNodeTypes.ANSWER,
+ title="Sample",
+ foo="bar",
+ compat_flag=True,
+ ),
+ }
node = _SampleNode(
- id="node-1",
- config=node_config,
+ node_id="node-1",
+ config=node_config["data"],
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 87ec2d5bce..dd75b32593 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -11,14 +11,16 @@ from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.file import File, FileTransferMethod
from graphon.node_events import NodeRunResult
from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
+from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError
from graphon.nodes.document_extractor.node import (
_extract_text_from_docx,
_extract_text_from_excel,
+ _extract_text_from_file,
_extract_text_from_pdf,
_extract_text_from_plain_text,
_normalize_docx_zip,
)
-from graphon.variables import ArrayFileSegment
+from graphon.variables import ArrayFileSegment, FileSegment
from graphon.variables.segments import ArrayStringSegment
from graphon.variables.variables import StringVariable
from tests.workflow_test_utils import build_test_graph_init_params
@@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params):
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
- node_config = {"id": "test_node_id", "data": node_data.model_dump()}
http_client = Mock()
node = DocumentExtractorNode(
- id="test_node_id",
- config=node_config,
+ node_id="test_node_id",
+ config=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
http_client=http_client,
@@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"]
- mock_excel_instance.parse.side_effect = [df, Exception("Parse error")]
+ mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_mixed_content"
@@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
# Mock ExcelFile
mock_excel_instance = Mock()
mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"]
- mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")]
+ mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")]
mock_excel_file.return_value = mock_excel_instance
file_content = b"fake_excel_all_bad_sheets"
@@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
assert mock_excel_instance.parse.call_count == 2
+@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook"))
+def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file):
+ with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"):
+ _extract_text_from_excel(b"broken")
+
+
@patch("pandas.ExcelFile")
def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
"""Test extracting text from Excel file with numeric column names."""
@@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
assert expected_manual == result
+@pytest.mark.parametrize(
+ ("extension", "mime_type"),
+ [
+ (".xlsx", "text/plain"),
+ (None, "application/vnd.ms-excel"),
+ ],
+)
+def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type):
+ file = Mock(spec=File)
+ file.extension = extension
+ file.mime_type = mime_type
+
+ with (
+ patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"excel",
+ ),
+ patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_excel",
+ return_value="excel text",
+ ) as mock_extract,
+ ):
+ result = _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+ assert result == "excel text"
+ mock_extract.assert_called_once_with(b"excel")
+
+
+def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
+ file = Mock(spec=File)
+ file.extension = None
+ file.mime_type = None
+
+ with patch(
+ "graphon.nodes.document_extractor.node._download_file_content",
+ return_value=b"unknown",
+ ):
+ with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"):
+ _extract_text_from_file(
+ document_extractor_node.http_client,
+ file,
+ unstructured_api_config=document_extractor_node._unstructured_api_config,
+ )
+
+
+def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_list = Mock(spec=ArrayFileSegment)
+ file_list.value = [Mock(spec=File)]
+ mock_graph_runtime_state.variable_pool.get.return_value = file_list
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("bad file"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "bad file"
+
+
+def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ side_effect=TextExtractionError("single file failed"),
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "single file failed"
+
+
+def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state):
+ document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+ file_segment = Mock(spec=FileSegment)
+ file_segment.value = Mock(spec=File)
+ mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+ with patch(
+ "graphon.nodes.document_extractor.node._extract_text_from_file",
+ return_value="single file text",
+ ):
+ result = document_extractor_node._run()
+
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs == {"text": "single file text"}
+
+
def _make_docx_zip(use_backslash: bool) -> bytes:
"""Helper to build a minimal in-memory DOCX zip.
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index 782750e02e..aa9a1360b0 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -19,6 +19,20 @@ from graphon.variables import ArrayFileSegment
from tests.workflow_test_utils import build_test_graph_init_params
+def _build_if_else_node(
+ *,
+ node_data: IfElseNodeData | dict[str, object],
+ init_params,
+ graph_runtime_state,
+) -> IfElseNode:
+ return IfElseNode(
+ node_id=str(uuid.uuid4()),
+ graph_init_params=init_params,
+ graph_runtime_state=graph_runtime_state,
+ config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
+ )
+
+
def test_execute_if_else_result_true():
graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
@@ -61,9 +75,8 @@ def test_execute_if_else_result_true():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "and",
@@ -104,13 +117,8 @@ def test_execute_if_else_result_true():
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -155,9 +163,8 @@ def test_execute_if_else_result_false():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
- node_config = {
- "id": "if-else",
- "data": {
+ node = _build_if_else_node(
+ node_data={
"title": "123",
"type": "if-else",
"logical_operator": "or",
@@ -174,13 +181,8 @@ def test_execute_if_else_result_false():
},
],
},
- }
-
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config=node_config,
)
# Mock db.session.close()
@@ -222,11 +224,6 @@ def test_array_file_contains_file_name():
],
)
- node_config = {
- "id": "if-else",
- "data": node_data.model_dump(),
- }
-
# Create properly configured mock for graph_init_params
graph_init_params = Mock()
graph_init_params.workflow_id = "test_workflow"
@@ -242,17 +239,12 @@ def test_array_file_contains_file_name():
}
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=graph_init_params,
- graph_runtime_state=Mock(),
- config=node_config,
- )
+ node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock())
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
@@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
"logical_operator": "and",
"conditions": [condition.model_dump()],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
@@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions():
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={
- "id": "if-else",
- "data": node_data,
- },
)
# Mock db.session.close()
@@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure():
}
],
}
- node = IfElseNode(
- id=str(uuid.uuid4()),
- graph_init_params=init_params,
+ node = _build_if_else_node(
+ node_data=node_data,
+ init_params=init_params,
graph_runtime_state=graph_runtime_state,
- config={"id": "if-else", "data": node_data},
)
# Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index b217e4e8e7..465a4c0ff4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -19,6 +19,15 @@ from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract
from graphon.variables import ArrayFileSegment
+def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
+ return ListOperatorNode(
+ node_id="test_node_id",
+ config=node_data,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=MagicMock(),
+ )
+
+
@pytest.fixture
def list_operator_node():
config = {
@@ -35,10 +44,6 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData.model_validate(config)
- node_config = {
- "id": "test_node_id",
- "data": node_data.model_dump(),
- }
# Create properly configured mock for graph_init_params
graph_init_params = MagicMock()
graph_init_params.workflow_id = "test_workflow"
@@ -54,12 +59,7 @@ def list_operator_node():
}
}
- node = ListOperatorNode(
- id="test_node_id",
- config=node_config,
- graph_init_params=graph_init_params,
- graph_runtime_state=MagicMock(),
- )
+ node = _build_list_operator_node(node_data, graph_init_params)
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
@@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node):
files = [
File(
filename="image1.jpg",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
- type=FileType.AUDIO,
+ file_type=FileType.AUDIO,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
@@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node):
def test_get_file_extract_string_func():
# Create a File object
file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename="test_file.txt",
extension=".txt",
@@ -156,7 +156,7 @@ def test_get_file_extract_string_func():
# Test with empty values
empty_file = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
filename=None,
extension=None,
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
index 543f9878de..5655f80737 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables):
inputs=user_inputs,
)
- config = {
- "id": "start",
- "data": StartNodeData(title="Start", variables=variables).model_dump(),
- }
+ node_data = StartNodeData(title="Start", variables=variables)
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
@@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables):
)
return StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
@@ -109,7 +106,7 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
- with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
+ with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"):
node._run()
@@ -248,25 +245,22 @@ def test_start_node_outputs_full_variable_pool_snapshot():
inputs={"profile": {"age": 20, "name": "Tom"}},
)
- config = {
- "id": "start",
- "data": StartNodeData(
- title="Start",
- variables=[
- VariableEntity(
- variable="profile",
- label="profile",
- type=VariableEntityType.JSON_OBJECT,
- required=True,
- )
- ],
- ).model_dump(),
- }
+ node_data = StartNodeData(
+ title="Start",
+ variables=[
+ VariableEntity(
+ variable="profile",
+ label="profile",
+ type=VariableEntityType.JSON_OBJECT,
+ required=True,
+ )
+ ],
+ )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node = StartNode(
- id="start",
- config=config,
+ node_id="start",
+ config=node_data,
graph_init_params=build_test_graph_init_params(
workflow_id="wf",
graph_config={},
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index c806181340..284af68319 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -13,6 +13,7 @@ from core.workflow.system_variables import build_system_variables
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
+from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variables.segments import ArrayFileSegment
@@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode:
runtime = _StubToolRuntime()
node = ToolNode(
- id="node-instance",
- config=config,
+ node_id="node-instance",
+ config=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
@@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode:
return node
-def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
+def _collect_events(generator: Generator) -> list[Any]:
events: list[Any] = []
try:
while True:
events.append(next(generator))
- except StopIteration as stop:
- return events, stop.value
+ except StopIteration:
+ return events
def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]:
@@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li
node_id=tool_node._node_id,
tool_runtime=ToolRuntimeHandle(raw=object()),
)
- return _collect_events(generator)
+ events = _collect_events(generator)
+ completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
+ assert completed_events
+ return events, completed_events[-1].node_run_result.llm_usage
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
@@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):
def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
file_obj = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="file-id",
filename="demo.pdf",
diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
index c8ddc53284..e3b5e3b591 100644
--- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
@@ -1,10 +1,10 @@
from collections.abc import Mapping
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
+from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from core.workflow.system_variables import build_system_variables
from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.runtime import GraphRuntimeState
from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -27,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
return init_params, runtime_state
-def _build_node_config() -> NodeConfigDict:
- return NodeConfigDictAdapter.validate_python(
- {
- "id": "node-1",
- "data": {
- "type": TRIGGER_PLUGIN_NODE_TYPE,
- "title": "Trigger Event",
- "plugin_id": "plugin-id",
- "provider_id": "provider-id",
- "event_name": "event-name",
- "subscription_id": "subscription-id",
- "plugin_unique_identifier": "plugin-unique-identifier",
- "event_parameters": {},
- },
- }
+def _build_node_data() -> TriggerEventNodeData:
+ return TriggerEventNodeData(
+ type=TRIGGER_PLUGIN_NODE_TYPE,
+ title="Trigger Event",
+ plugin_id="plugin-id",
+ provider_id="provider-id",
+ event_name="event-name",
+ subscription_id="subscription-id",
+ plugin_unique_identifier="plugin-unique-identifier",
+ event_parameters={},
)
def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
init_params, runtime_state = _build_context(graph_config={})
node = TriggerEventNode(
- id="node-1",
- config=_build_node_config(),
+ node_id="node-1",
+ config=_build_node_data(),
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
index 1bbc12b23f..07d03bec05 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
@@ -30,11 +30,6 @@ def create_webhook_node(
tenant_id: str = "test-tenant",
) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "webhook-node-1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="test-workflow",
graph_config={},
@@ -56,8 +51,8 @@ def create_webhook_node(
)
node = TriggerWebhookNode(
- id="webhook-node-1",
- config=node_config,
+ node_id="webhook-node-1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -66,10 +61,6 @@ def create_webhook_node(
runtime_state.app_config = Mock()
runtime_state.app_config.tenant_id = tenant_id
- # Provide compatibility alias expected by node implementation
- # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests
- node.node_id = node.id
-
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
index 427afa96ec..b839490d3c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py
@@ -24,11 +24,6 @@ from tests.workflow_test_utils import build_test_variable_pool
def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode:
"""Helper function to create a webhook node with proper initialization."""
- node_config = {
- "id": "1",
- "data": webhook_data.model_dump(),
- }
-
graph_init_params = GraphInitParams(
workflow_id="1",
graph_config={},
@@ -48,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
start_at=0,
)
node = TriggerWebhookNode(
- id="1",
- config=node_config,
+ node_id="1",
+ config=webhook_data,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
@@ -57,9 +52,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
# Provide tenant_id for conversion path
runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})()
- # Compatibility alias for some nodes referencing `self.node_id`
- node.node_id = node.id
-
return node
@@ -225,7 +217,7 @@ def test_webhook_node_run_with_file_params():
"""Test webhook node execution with file parameter extraction."""
# Create mock file objects
file1 = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="image.jpg",
@@ -234,7 +226,7 @@ def test_webhook_node_run_with_file_params():
)
file2 = File(
- type=FileType.DOCUMENT,
+ file_type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file2",
filename="document.pdf",
@@ -269,8 +261,19 @@ def test_webhook_node_run_with_file_params():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
@@ -284,7 +287,7 @@ def test_webhook_node_run_with_file_params():
def test_webhook_node_run_mixed_parameters():
"""Test webhook node execution with mixed parameter types."""
file_obj = File(
- type=FileType.IMAGE,
+ file_type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="test.jpg",
@@ -317,8 +320,19 @@ def test_webhook_node_run_mixed_parameters():
# Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
- def _to_file(*, mapping):
- return File.model_validate(mapping)
+ def _to_file(*, mapping: dict[str, Any]) -> File:
+ return File(
+ file_id=mapping.get("id"),
+ file_type=FileType(mapping["type"]),
+ transfer_method=FileTransferMethod(mapping["transfer_method"]),
+ related_id=mapping.get("related_id"),
+ filename=mapping.get("filename"),
+ extension=mapping.get("extension"),
+ mime_type=mapping.get("mime_type"),
+ size=mapping.get("size", -1),
+ storage_key=mapping.get("storage_key", ""),
+ remote_url=mapping.get("url"),
+ )
mock_file_factory.side_effect = _to_file
result = node._run()
diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
new file mode 100644
index 0000000000..8b5fceeb37
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
@@ -0,0 +1,350 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from core.workflow.human_input_adapter import (
+ DeliveryMethodType,
+ EmailDeliveryConfig,
+ EmailDeliveryMethod,
+ EmailRecipients,
+ WebAppDeliveryMethod,
+ _WebAppDeliveryConfig,
+ adapt_human_input_node_data_for_graph,
+ adapt_node_config_for_graph,
+ adapt_node_data_for_graph,
+ is_human_input_webapp_enabled,
+ parse_human_input_delivery_methods,
+)
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.base.variable_template_parser import VariableTemplateParser
+
+
+def test_email_delivery_config_helpers_render_and_sanitize_text() -> None:
+ variable_pool = SimpleNamespace(
+ convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42"))
+ )
+
+ rendered = EmailDeliveryConfig.render_body_template(
+ body="Open {{#url#}} and use {{#node.value#}}",
+ url="https://example.com",
+ variable_pool=variable_pool,
+ )
+ sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team")
+ html = EmailDeliveryConfig.render_markdown_body(
+ "**Hello** [mail](mailto:test@example.com)"
+ )
+
+ assert rendered == "Open https://example.com and use 42"
+ assert sanitized == "Hello alert(1) Team"
+ assert "Hello" in html
+ assert " Team")
- html = EmailDeliveryConfig.render_markdown_body(
- "**Hello** [mail](mailto:test@example.com)"
- )
-
- assert rendered == "Open https://example.com and use 42"
- assert sanitized == "Hello alert(1) Team"
- assert "Hello" in html
- assert "