mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:06:51 +08:00
merge
This commit is contained in:
commit
3bb3670cb5
@ -20,11 +20,11 @@
|
||||
```typescript
|
||||
// ❌ WRONG: Don't mock base components
|
||||
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||
vi.mock('@/app/components/base/ui/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
vi.mock('@langgenius/dify-ui/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
|
||||
// ✅ CORRECT: Import and use real base components
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { Button } from '@/app/components/base/ui/button'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
// They will render normally in tests
|
||||
```
|
||||
|
||||
|
||||
18
.github/workflows/pyrefly-diff-comment.yml
vendored
18
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -76,13 +76,11 @@ jobs:
|
||||
diff += '\\n\\n... (truncated) ...';
|
||||
}
|
||||
|
||||
const body = diff.trim()
|
||||
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body,
|
||||
});
|
||||
if (diff.trim()) {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: prNumber,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>',
|
||||
});
|
||||
}
|
||||
|
||||
34
.github/workflows/web-tests.yml
vendored
34
.github/workflows/web-tests.yml
vendored
@ -89,3 +89,37 @@ jobs:
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
dify-ui-test:
|
||||
name: dify-ui Tests
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./packages/dify-ui
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Install Chromium for Browser Mode
|
||||
run: vp exec playwright install --with-deps chromium
|
||||
|
||||
- name: Run dify-ui tests
|
||||
run: vp test run --coverage --silent=passed-only
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: packages/dify-ui/coverage
|
||||
flags: dify-ui
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
1
api/constants/dsl_version.py
Normal file
1
api/constants/dsl_version.py
Normal file
@ -0,0 +1 @@
|
||||
CURRENT_APP_DSL_VERSION = "0.6.0"
|
||||
@ -45,7 +45,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
|
||||
@ -102,7 +102,7 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
value_type = workflow_draft_var.value_type
|
||||
return value_type.exposed_type().value
|
||||
return str(value_type.exposed_type())
|
||||
|
||||
|
||||
class FullContentDict(TypedDict):
|
||||
@ -122,7 +122,7 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
|
||||
|
||||
result: FullContentDict = {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"value_type": str(variable_file.value_type.exposed_type()),
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
@ -598,7 +598,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value_type": str(v.value_type.exposed_type()),
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
|
||||
@ -84,10 +84,10 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
return str(SegmentType(value).exposed_type())
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
|
||||
@ -42,7 +42,7 @@ from graphon.model_runtime.entities import (
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class ModelConfigConverter:
|
||||
|
||||
@ -18,7 +18,7 @@ from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.model import App, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -59,7 +59,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
|
||||
|
||||
|
||||
@ -12,13 +12,14 @@ from typing import TYPE_CHECKING, Literal
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from extensions.ext_storage import storage
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
|
||||
from graphon.file.protocols import WorkflowFileRuntimeProtocol
|
||||
from graphon.file.runtime import set_workflow_file_runtime
|
||||
from graphon.http.protocols import HttpResponseProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.file import File
|
||||
@ -43,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@ -349,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
execution.total_tokens = runtime_state.total_tokens
|
||||
execution.total_steps = runtime_state.node_run_steps
|
||||
execution.outputs = execution.outputs or runtime_state.outputs
|
||||
execution.exceptions_count = runtime_state.exceptions_count
|
||||
execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count)
|
||||
|
||||
def _update_node_execution(
|
||||
self,
|
||||
|
||||
@ -352,11 +352,11 @@ class DatasourceManager:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.CUSTOM,
|
||||
file_type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
|
||||
@ -31,7 +31,7 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from graphon.model_runtime.runtime import ModelRuntime
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -363,7 +363,7 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
@ -912,7 +912,7 @@ class ProviderConfiguration(BaseModel):
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
if key in provider_credential_secret_variables and isinstance(value, str):
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
@ -102,7 +102,7 @@ class TemplateTransformer(ABC):
|
||||
|
||||
@classmethod
|
||||
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
|
||||
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
|
||||
inputs_json_str = dumps_with_segments(inputs).encode()
|
||||
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
|
||||
return input_base64_encoded
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError
|
||||
from configs import dify_config
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from core.tools.errors import ToolSSRFError
|
||||
from graphon.http.response import HttpResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -267,4 +268,47 @@ class SSRFProxy:
|
||||
return patch(url=url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def _to_graphon_http_response(response: httpx.Response) -> HttpResponse:
|
||||
"""Convert an ``httpx`` response into Graphon's transport-agnostic wrapper."""
|
||||
return HttpResponse(
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
content=response.content,
|
||||
url=str(response.url) if response.url else None,
|
||||
reason_phrase=response.reason_phrase,
|
||||
fallback_text=response.text,
|
||||
)
|
||||
|
||||
|
||||
class GraphonSSRFProxy:
|
||||
"""Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``."""
|
||||
|
||||
@property
|
||||
def max_retries_exceeded_error(self) -> type[Exception]:
|
||||
return max_retries_exceeded_error
|
||||
|
||||
@property
|
||||
def request_error(self) -> type[Exception]:
|
||||
return request_error
|
||||
|
||||
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse:
|
||||
return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs))
|
||||
|
||||
|
||||
ssrf_proxy = SSRFProxy()
|
||||
graphon_ssrf_proxy = GraphonSSRFProxy()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from typing import IO, Any, Literal, Optional, Union, cast, overload
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import PluginCredentialType
|
||||
@ -18,15 +18,17 @@ from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFe
|
||||
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
|
||||
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from models.provider import ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ModelInstance:
|
||||
@ -168,7 +170,7 @@ class ModelInstance:
|
||||
return cast(
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@ -193,7 +195,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, LargeLanguageModel):
|
||||
raise Exception("Model type instance is not LargeLanguageModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=list(prompt_messages),
|
||||
@ -213,7 +215,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@ -235,7 +237,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
@ -252,7 +254,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TextEmbeddingModel):
|
||||
raise Exception("Model type instance is not TextEmbeddingModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
self.model_type_instance.get_num_tokens,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
@ -277,7 +279,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@ -305,7 +307,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, RerankModel):
|
||||
raise Exception("Model type instance is not RerankModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
@ -324,7 +326,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, ModerationModel):
|
||||
raise Exception("Model type instance is not ModerationModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
@ -340,7 +342,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, Speech2TextModel):
|
||||
raise Exception("Model type instance is not Speech2TextModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
@ -357,14 +359,14 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
self.model_type_instance.invoke,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""
|
||||
Round-robin invoke
|
||||
:param function: function to invoke
|
||||
|
||||
@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime):
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US
|
||||
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
|
||||
)
|
||||
elif icon_type.lower() == "icon_small_dark":
|
||||
if not provider_schema.icon_small_dark:
|
||||
raise ValueError(f"Provider {provider} does not have small dark icon.")
|
||||
file_name = (
|
||||
provider_schema.icon_small_dark.zh_Hans
|
||||
provider_schema.icon_small_dark.zh_hans
|
||||
if lang.lower() == "zh_hans"
|
||||
else provider_schema.icon_small_dark.en_US
|
||||
else provider_schema.icon_small_dark.en_us
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported icon type: {icon_type}.")
|
||||
|
||||
@ -10,7 +10,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
|
||||
|
||||
class AgentHistoryPromptTransform(PromptTransform):
|
||||
|
||||
@ -14,7 +14,7 @@ from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from libs import helper
|
||||
from models.dataset import Embedding
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`).
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
_closed: bool
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str):
|
||||
"""Initialize with file path."""
|
||||
self._closed = False
|
||||
self.file_path = file_path
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor):
|
||||
elif not os.path.isfile(self.file_path):
|
||||
raise ValueError(f"File path {self.file_path} is not a valid file or url")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Best-effort cleanup for downloaded temporary files."""
|
||||
if getattr(self, "_closed", False):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
temp_file = getattr(self, "temp_file", None)
|
||||
if temp_file is None:
|
||||
return
|
||||
|
||||
try:
|
||||
close_result = temp_file.close()
|
||||
if inspect.isawaitable(close_result):
|
||||
close_awaitable = getattr(close_result, "close", None)
|
||||
if callable(close_awaitable):
|
||||
close_awaitable()
|
||||
except Exception:
|
||||
logger.debug("Failed to cleanup downloaded word temp file", exc_info=True)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "temp_file"):
|
||||
self.temp_file.close()
|
||||
self.close()
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Load given path as single page."""
|
||||
|
||||
@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
try:
|
||||
# Create File object directly (similar to DatasetRetrieval)
|
||||
file_obj = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@ -68,7 +68,7 @@ from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.helper import parse_uuid_str_or_none
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
@ -517,11 +517,11 @@ class DatasetRetrieval:
|
||||
if attachments_with_bindings:
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
attachment_info = File(
|
||||
id=upload_file.id,
|
||||
file_id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
reference=build_file_reference(
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Any, Literal
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
|
||||
from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
|
||||
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
|
||||
@ -8,7 +8,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@ -28,7 +28,7 @@ class ToolFileManager:
|
||||
def _build_graph_file_reference(tool_file: ToolFile) -> File:
|
||||
extension = guess_extension(tool_file.mimetype) or ".bin"
|
||||
return File(
|
||||
type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
file_type=get_file_type_by_mime_type(tool_file.mimetype),
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
|
||||
@ -1082,7 +1082,12 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.value
|
||||
if not isinstance(variable_selector, list) or not all(
|
||||
isinstance(selector_part, str) for selector_part in variable_selector
|
||||
):
|
||||
raise ToolParameterError("Variable tool input must be a variable selector")
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
|
||||
@ -21,7 +21,7 @@ from graphon.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from models.tools import ToolModelInvoke
|
||||
|
||||
|
||||
@ -357,7 +357,10 @@ class WorkflowTool(Tool):
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
transfer_method_value = file_dict.get("transfer_method")
|
||||
if not isinstance(transfer_method_value, str):
|
||||
raise ValueError("Workflow file mapping is missing a valid transfer_method")
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_id
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
"""Workflow-layer adapters for legacy human-input payload keys.
|
||||
"""Workflow-to-Graphon adapters for persisted node payloads.
|
||||
|
||||
Stored workflow graphs and editor payloads may still use Dify-specific human
|
||||
input recipient keys. Normalize them here before handing configs to
|
||||
`graphon` so graph-owned models only see graph-neutral field names.
|
||||
Stored workflow graphs and editor payloads still contain a small set of
|
||||
Dify-owned field spellings and value shapes. Adapt them here before handing the
|
||||
payload to Graphon so Graphon-owned models only see current contracts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -185,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
|
||||
@ -215,7 +215,7 @@ def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | Bas
|
||||
|
||||
|
||||
def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
|
||||
normalized = normalize_human_input_node_data_for_graph(node_data)
|
||||
normalized = adapt_human_input_node_data_for_graph(node_data)
|
||||
raw_delivery_methods = normalized.get("delivery_methods")
|
||||
if not isinstance(raw_delivery_methods, list):
|
||||
return []
|
||||
@ -229,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
|
||||
return False
|
||||
|
||||
|
||||
def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_data)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
|
||||
|
||||
if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return normalized
|
||||
return normalize_human_input_node_data_for_graph(normalized)
|
||||
node_type = normalized.get("type")
|
||||
if node_type == BuiltinNodeTypes.HUMAN_INPUT:
|
||||
return adapt_human_input_node_data_for_graph(normalized)
|
||||
if node_type == BuiltinNodeTypes.TOOL:
|
||||
return _adapt_tool_node_data_for_graph(normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
|
||||
normalized = _copy_mapping(node_config)
|
||||
if normalized is None:
|
||||
raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
|
||||
@ -248,10 +251,65 @@ def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel)
|
||||
if data_mapping is None:
|
||||
return normalized
|
||||
|
||||
normalized["data"] = normalize_node_data_for_graph(data_mapping)
|
||||
normalized["data"] = adapt_node_data_for_graph(data_mapping)
|
||||
return normalized
|
||||
|
||||
|
||||
def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(node_data)
|
||||
|
||||
raw_tool_configurations = normalized.get("tool_configurations")
|
||||
if not isinstance(raw_tool_configurations, Mapping):
|
||||
return normalized
|
||||
|
||||
existing_tool_parameters = normalized.get("tool_parameters")
|
||||
normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {}
|
||||
normalized_tool_configurations: dict[str, Any] = {}
|
||||
found_legacy_tool_inputs = False
|
||||
|
||||
for name, value in raw_tool_configurations.items():
|
||||
if not isinstance(value, Mapping):
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
input_type = value.get("type")
|
||||
input_value = value.get("value")
|
||||
if input_type not in {"mixed", "variable", "constant"}:
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
found_legacy_tool_inputs = True
|
||||
normalized_tool_parameters.setdefault(name, dict(value))
|
||||
|
||||
flattened_value = _flatten_legacy_tool_configuration_value(
|
||||
input_type=input_type,
|
||||
input_value=input_value,
|
||||
)
|
||||
if flattened_value is not None:
|
||||
normalized_tool_configurations[name] = flattened_value
|
||||
|
||||
if not found_legacy_tool_inputs:
|
||||
return normalized
|
||||
|
||||
normalized["tool_parameters"] = normalized_tool_parameters
|
||||
normalized["tool_configurations"] = normalized_tool_configurations
|
||||
return normalized
|
||||
|
||||
|
||||
def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None:
|
||||
if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool):
|
||||
return input_value
|
||||
|
||||
if (
|
||||
input_type == "variable"
|
||||
and isinstance(input_value, list)
|
||||
and all(isinstance(item, str) for item in input_value)
|
||||
):
|
||||
return "{{#" + ".".join(input_value) + "#}}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(recipients)
|
||||
|
||||
@ -291,9 +349,9 @@ __all__ = [
|
||||
"MemberRecipient",
|
||||
"WebAppDeliveryMethod",
|
||||
"_WebAppDeliveryConfig",
|
||||
"adapt_human_input_node_data_for_graph",
|
||||
"adapt_node_config_for_graph",
|
||||
"adapt_node_data_for_graph",
|
||||
"is_human_input_webapp_enabled",
|
||||
"normalize_human_input_node_data_for_graph",
|
||||
"normalize_node_config_for_graph",
|
||||
"normalize_node_data_for_graph",
|
||||
"parse_human_input_delivery_methods",
|
||||
]
|
||||
@ -15,12 +15,12 @@ from core.helper.code_executor.code_executor import (
|
||||
CodeExecutionError,
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.helper.ssrf_proxy import graphon_ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from core.workflow.human_input_compat import normalize_node_config_for_graph
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.node_runtime import (
|
||||
DifyFileReferenceFactory,
|
||||
DifyHumanInputNodeRuntime,
|
||||
@ -46,7 +46,7 @@ from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph.graph import NodeFactory
|
||||
from graphon.model_runtime.memory import PromptMessageMemory
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.code.code_node import WorkflowCodeExecutor
|
||||
from graphon.nodes.code.entities import CodeLanguage
|
||||
@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]
|
||||
|
||||
|
||||
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
"""Resolve the production node class for the requested type/version."""
|
||||
node_mapping = get_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
self._http_request_http_client = ssrf_proxy
|
||||
self._http_request_http_client = graphon_ssrf_proxy
|
||||
self._bound_tool_file_manager_factory = lambda: DifyToolFileManager(
|
||||
self._dify_context,
|
||||
conversation_id_getter=self._conversation_id,
|
||||
@ -364,10 +365,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
# Graph configs are initially validated against permissive shared node data.
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@ -391,7 +396,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@ -405,7 +410,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
},
|
||||
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=True,
|
||||
include_llm_file_saver=True,
|
||||
@ -415,7 +420,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
),
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
|
||||
node_class=node_class,
|
||||
node_data=node_data,
|
||||
node_data=resolved_node_data,
|
||||
wrap_model_instance=True,
|
||||
include_http_client=False,
|
||||
include_llm_file_saver=False,
|
||||
@ -436,8 +441,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
@ -448,7 +453,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class.
|
||||
"""
|
||||
return node_class.validate_node_data(node_data)
|
||||
validate_node_data = getattr(node_class, "validate_node_data", None)
|
||||
if callable(validate_node_data):
|
||||
return cast("BaseNodeData", validate_node_data(node_data))
|
||||
return node_data
|
||||
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast, overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -41,7 +41,7 @@ from graphon.model_runtime.entities.llm_entities import (
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
from graphon.nodes.llm.runtime_protocols import (
|
||||
PreparedLLMProtocol,
|
||||
@ -64,7 +64,7 @@ from models.dataset import SegmentAttachmentBinding
|
||||
from models.model import UploadFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .human_input_compat import (
|
||||
from .human_input_adapter import (
|
||||
BoundRecipient,
|
||||
DeliveryChannelConfig,
|
||||
DeliveryMethodType,
|
||||
@ -173,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
|
||||
return self._model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResult: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Mapping[str, Any],
|
||||
tools: Sequence[PromptMessageTool] | None,
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -190,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[False],
|
||||
) -> LLMResultWithStructuredOutput: ...
|
||||
|
||||
@overload
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
json_schema: Mapping[str, Any],
|
||||
model_parameters: Mapping[str, Any],
|
||||
stop: Sequence[str] | None,
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
@ -35,18 +34,18 @@ class AgentNode(Node[AgentNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: AgentNodeData,
|
||||
*,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
strategy_resolver: AgentStrategyResolver,
|
||||
presentation_provider: AgentStrategyPresentationProvider,
|
||||
runtime_support: AgentRuntimeSupport,
|
||||
message_transformer: AgentMessageTransformer,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -7,7 +7,6 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeExecutionType,
|
||||
@ -36,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: DatasourceNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@ -7,7 +7,6 @@ from core.rag.index_processor.index_processor_base import SummaryIndexSettingDic
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.base.node import Node
|
||||
@ -32,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeIndexNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> None:
|
||||
super().__init__(id, config, graph_init_params, graph_runtime_state)
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.index_processor = IndexProcessor()
|
||||
self.summary_index_service = SummaryIndex()
|
||||
|
||||
|
||||
@ -14,7 +14,6 @@ from core.rag.data_post_processor.data_post_processor import RerankingModelDict,
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@ -50,6 +49,18 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None:
|
||||
if value is None or isinstance(value, (str, float)):
|
||||
return value
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def _normalize_metadata_filter_sequence_item(value: object) -> str:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
@ -59,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
node_id: str,
|
||||
config: KnowledgeRetrievalNodeData,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@ -282,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
resolved_conditions: list[Condition] = []
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
resolved_value: str | Sequence[str] | int | float | None
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = segment_group.value[0].to_object()
|
||||
resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values = []
|
||||
for v in value: # type: ignore
|
||||
resolved_values: list[str] = []
|
||||
for v in value:
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(segment_group.value[0].to_object())
|
||||
resolved_values.append(
|
||||
_normalize_metadata_filter_sequence_item(segment_group.value[0].to_object())
|
||||
)
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
|
||||
@ -148,11 +148,11 @@ def _build_from_local_file(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension="." + row.extension,
|
||||
mime_type=row.mime_type,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=row.source_url,
|
||||
reference=build_file_reference(record_id=str(row.id)),
|
||||
@ -196,11 +196,11 @@ def _build_from_remote_url(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||
@ -222,9 +222,9 @@ def _build_from_remote_url(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=filename,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
mime_type=mime_type,
|
||||
@ -263,9 +263,9 @@ def _build_from_tool_file(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
file_id=mapping.get("id"),
|
||||
filename=tool_file.name,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||
@ -306,9 +306,9 @@ def _build_from_datasource_file(
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("datasource_file_id"),
|
||||
file_id=mapping.get("datasource_file_id"),
|
||||
filename=datasource_file.name,
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
remote_url=datasource_file.source_url,
|
||||
reference=build_file_reference(record_id=str(datasource_file.id)),
|
||||
|
||||
@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False):
|
||||
|
||||
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type.exposed_type().value
|
||||
return str(v.value_type.exposed_type())
|
||||
else:
|
||||
value_type = v.get("value_type")
|
||||
if value_type is None:
|
||||
raise ValueError("value_type is required but not provided")
|
||||
return value_type.exposed_type().value
|
||||
return str(value_type.exposed_type())
|
||||
|
||||
@ -57,10 +57,10 @@ class ConversationVariableResponse(ResponseModel):
|
||||
def _normalize_value_type(cls, value: Any) -> str:
|
||||
exposed_type = getattr(value, "exposed_type", None)
|
||||
if callable(exposed_type):
|
||||
return str(exposed_type().value)
|
||||
return str(exposed_type())
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return str(SegmentType(value).exposed_type().value)
|
||||
return str(SegmentType(value).exposed_type())
|
||||
except ValueError:
|
||||
return value
|
||||
try:
|
||||
|
||||
@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": value.value,
|
||||
"value_type": value.value_type.exposed_type().value,
|
||||
"value_type": str(value.value_type.exposed_type()),
|
||||
"description": value.description,
|
||||
}
|
||||
if isinstance(value, dict):
|
||||
|
||||
@ -6,8 +6,8 @@ from flask_login import current_user
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.http_client_pooling import get_pooled_http_client
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
with session_factory.create_session() as session:
|
||||
data_source_binding = session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
)
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
session.add(new_data_source_binding)
|
||||
session.commit()
|
||||
|
||||
def save_internal_access_token(self, access_token: str) -> None:
|
||||
workspace_name = self.notion_workspace_name(access_token)
|
||||
@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
with session_factory.create_session() as session:
|
||||
data_source_binding = session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.access_token == access_token,
|
||||
)
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
session.add(new_data_source_binding)
|
||||
session.commit()
|
||||
|
||||
def sync_data_source(self, binding_id: str) -> None:
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.id == binding_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
with session_factory.create_session() as session:
|
||||
data_source_binding = session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.id == binding_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
|
||||
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
|
||||
pages: list[NotionPageSummary] = []
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@ -36,24 +37,27 @@ class WorkflowComment(Base):
|
||||
|
||||
__tablename__ = "workflow_comments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
|
||||
Index("workflow_comments_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(db.Float)
|
||||
position_y: Mapped[float] = mapped_column(db.Float)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(sa.Float)
|
||||
position_y: Mapped[float] = mapped_column(sa.Float)
|
||||
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
|
||||
resolved: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false"))
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
|
||||
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
|
||||
|
||||
# Relationships
|
||||
@ -143,23 +147,22 @@ class WorkflowCommentReply(Base):
|
||||
|
||||
__tablename__ = "workflow_comment_replies"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
Index("comment_replies_comment_idx", "comment_id"),
|
||||
Index("comment_replies_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
|
||||
comment: Mapped["WorkflowComment"] = relationship(
|
||||
"WorkflowComment", back_populates="replies")
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
@ -187,24 +190,23 @@ class WorkflowCommentMention(Base):
|
||||
|
||||
__tablename__ = "workflow_comment_mentions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
Index("comment_mentions_comment_idx", "comment_id"),
|
||||
Index("comment_mentions_reply_idx", "reply_id"),
|
||||
Index("comment_mentions_user_idx", "mentioned_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, server_default=sa.text("uuidv7()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
reply_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
|
||||
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
|
||||
comment: Mapped["WorkflowComment"] = relationship(
|
||||
"WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]
|
||||
] = relationship("WorkflowCommentReply")
|
||||
|
||||
@property
|
||||
def mentioned_user_account(self):
|
||||
|
||||
@ -6,7 +6,7 @@ import sqlalchemy as sa
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from core.workflow.human_input_compat import DeliveryMethodType
|
||||
from core.workflow.human_input_adapter import DeliveryMethodType
|
||||
from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from libs.helper import generate_string
|
||||
|
||||
|
||||
@ -5,7 +5,8 @@ from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.file_reference import parse_file_reference
|
||||
from graphon.file import File, FileTransferMethod
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@ -43,6 +44,124 @@ def resolve_file_mapping_tenant_id(
|
||||
return tenant_resolver()
|
||||
|
||||
|
||||
def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File:
|
||||
"""Build a graph `File` directly from serialized metadata."""
|
||||
|
||||
def _coerce_file_type(value: Any) -> FileType:
|
||||
if isinstance(value, FileType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return FileType.value_of(value)
|
||||
raise ValueError("file type is required in file mapping")
|
||||
|
||||
mapping = dict(file_mapping)
|
||||
transfer_method_value = mapping.get("transfer_method")
|
||||
if isinstance(transfer_method_value, FileTransferMethod):
|
||||
transfer_method = transfer_method_value
|
||||
elif isinstance(transfer_method_value, str):
|
||||
transfer_method = FileTransferMethod.value_of(transfer_method_value)
|
||||
else:
|
||||
raise ValueError("transfer_method is required in file mapping")
|
||||
|
||||
file_id = mapping.get("file_id")
|
||||
if not isinstance(file_id, str) or not file_id:
|
||||
legacy_id = mapping.get("id")
|
||||
file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None
|
||||
|
||||
related_id = resolve_file_record_id(mapping)
|
||||
if related_id is None:
|
||||
raw_related_id = mapping.get("related_id")
|
||||
related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None
|
||||
|
||||
remote_url = mapping.get("remote_url")
|
||||
if not isinstance(remote_url, str) or not remote_url:
|
||||
url = mapping.get("url")
|
||||
remote_url = url if isinstance(url, str) and url else None
|
||||
|
||||
reference = mapping.get("reference")
|
||||
if not isinstance(reference, str) or not reference:
|
||||
reference = None
|
||||
|
||||
filename = mapping.get("filename")
|
||||
if not isinstance(filename, str):
|
||||
filename = None
|
||||
|
||||
extension = mapping.get("extension")
|
||||
if not isinstance(extension, str):
|
||||
extension = None
|
||||
|
||||
mime_type = mapping.get("mime_type")
|
||||
if not isinstance(mime_type, str):
|
||||
mime_type = None
|
||||
|
||||
size = mapping.get("size", -1)
|
||||
if not isinstance(size, int):
|
||||
size = -1
|
||||
|
||||
storage_key = mapping.get("storage_key")
|
||||
if not isinstance(storage_key, str):
|
||||
storage_key = None
|
||||
|
||||
tenant_id = mapping.get("tenant_id")
|
||||
if not isinstance(tenant_id, str):
|
||||
tenant_id = None
|
||||
|
||||
dify_model_identity = mapping.get("dify_model_identity")
|
||||
if not isinstance(dify_model_identity, str):
|
||||
dify_model_identity = FILE_MODEL_IDENTITY
|
||||
|
||||
tool_file_id = mapping.get("tool_file_id")
|
||||
if not isinstance(tool_file_id, str):
|
||||
tool_file_id = None
|
||||
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if not isinstance(upload_file_id, str):
|
||||
upload_file_id = None
|
||||
|
||||
datasource_file_id = mapping.get("datasource_file_id")
|
||||
if not isinstance(datasource_file_id, str):
|
||||
datasource_file_id = None
|
||||
|
||||
return File(
|
||||
file_id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=remote_url,
|
||||
reference=reference,
|
||||
related_id=related_id,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=size,
|
||||
storage_key=storage_key,
|
||||
dify_model_identity=dify_model_identity,
|
||||
url=remote_url,
|
||||
tool_file_id=tool_file_id,
|
||||
upload_file_id=upload_file_id,
|
||||
datasource_file_id=datasource_file_id,
|
||||
)
|
||||
|
||||
|
||||
def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any:
|
||||
"""Recursively rebuild serialized graph file payloads into `File` objects.
|
||||
|
||||
`graphon` 0.2.2 no longer accepts legacy serialized file mappings via
|
||||
`model_validate_json()`. Dify keeps this recovery path at the model boundary
|
||||
so historical JSON blobs remain readable without reintroducing global graph
|
||||
patches or test-local coercion.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [rebuild_serialized_graph_files_without_lookup(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
if maybe_file_object(value):
|
||||
return build_file_from_mapping_without_lookup(file_mapping=value)
|
||||
return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()}
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def build_file_from_stored_mapping(
|
||||
*,
|
||||
file_mapping: Mapping[str, Any],
|
||||
@ -76,12 +195,7 @@ def build_file_from_stored_mapping(
|
||||
pass
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
|
||||
remote_url = mapping.get("remote_url")
|
||||
if not isinstance(remote_url, str) or not remote_url:
|
||||
url = mapping.get("url")
|
||||
if isinstance(url, str) and url:
|
||||
mapping["remote_url"] = url
|
||||
return File.model_validate(mapping)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=mapping)
|
||||
|
||||
return file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
|
||||
@ -24,7 +24,7 @@ from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.workflow.human_input_compat import normalize_node_config_for_graph
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.variable_prefixes import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
@ -64,7 +64,10 @@ from .base import Base, DefaultFieldsDCMixin, TypeBase
|
||||
from .engine import db
|
||||
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
from .utils.file_input_compat import build_file_from_stored_mapping
|
||||
from .utils.file_input_compat import (
|
||||
build_file_from_mapping_without_lookup,
|
||||
build_file_from_stored_mapping,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -292,7 +295,7 @@ class Workflow(Base): # bug
|
||||
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise NodeNotFoundError(node_id)
|
||||
return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
|
||||
@staticmethod
|
||||
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
|
||||
@ -1690,7 +1693,7 @@ class WorkflowDraftVariable(Base):
|
||||
return cast(Any, value)
|
||||
normalized_file = dict(value)
|
||||
normalized_file.pop("tenant_id", None)
|
||||
return File.model_validate(normalized_file)
|
||||
return build_file_from_mapping_without_lookup(file_mapping=normalized_file)
|
||||
elif isinstance(value, list) and value:
|
||||
value_list = cast(list[Any], value)
|
||||
first: Any = value_list[0]
|
||||
@ -1700,7 +1703,7 @@ class WorkflowDraftVariable(Base):
|
||||
for item in value_list:
|
||||
normalized_file = dict(cast(dict[str, Any], item))
|
||||
normalized_file.pop("tenant_id", None)
|
||||
file_list.append(File.model_validate(normalized_file))
|
||||
file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file))
|
||||
return cast(Any, file_list)
|
||||
else:
|
||||
return cast(Any, value)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Tencent APM tracing implementation with separated concerns
|
||||
"""
|
||||
"""Tencent APM tracing with idempotent client cleanup."""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
"""
|
||||
Tencent APM trace implementation with single responsibility principle.
|
||||
Acts as a coordinator that delegates specific tasks to specialized classes.
|
||||
|
||||
The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
|
||||
explicitly in tests or implicitly during garbage collection, so shutdown
|
||||
must be safe to call multiple times.
|
||||
"""
|
||||
|
||||
trace_client: TencentTraceClient
|
||||
_closed: bool
|
||||
|
||||
def __init__(self, tencent_config: TencentConfig):
|
||||
super().__init__(tencent_config)
|
||||
self._closed = False
|
||||
self.trace_client = TencentTraceClient(
|
||||
service_name=tencent_config.service_name,
|
||||
endpoint=tencent_config.endpoint,
|
||||
@ -513,10 +520,25 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
except Exception:
|
||||
logger.debug("[Tencent APM] Failed to record message trace duration")
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure proper cleanup on garbage collection."""
|
||||
def close(self) -> None:
|
||||
"""Synchronously and idempotently shutdown the underlying trace client."""
|
||||
if getattr(self, "_closed", False):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
trace_client = getattr(self, "trace_client", None)
|
||||
if trace_client is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if hasattr(self, "trace_client"):
|
||||
self.trace_client.shutdown()
|
||||
shutdown_result = trace_client.shutdown()
|
||||
if inspect.isawaitable(shutdown_result):
|
||||
close_awaitable = getattr(shutdown_result, "close", None)
|
||||
if callable(close_awaitable):
|
||||
close_awaitable()
|
||||
except Exception:
|
||||
logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
|
||||
def __del__(self):
|
||||
"""Ensure best-effort cleanup on garbage collection without retrying shutdown."""
|
||||
self.close()
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import gc
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
import warnings
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from dify_trace_tencent.config import TencentConfig
|
||||
@ -632,13 +634,38 @@ class TestTencentDataTrace:
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
tencent_data_trace._record_message_trace_duration(trace_info)
|
||||
|
||||
def test_del(self, tencent_data_trace):
|
||||
def test_close(self, tencent_data_trace):
|
||||
client = tencent_data_trace.trace_client
|
||||
tencent_data_trace.__del__()
|
||||
tencent_data_trace.close()
|
||||
client.shutdown.assert_called_once()
|
||||
|
||||
def test_del_exception(self, tencent_data_trace):
|
||||
def test_close_is_idempotent(self, tencent_data_trace):
|
||||
client = tencent_data_trace.trace_client
|
||||
|
||||
tencent_data_trace.close()
|
||||
tencent_data_trace.close()
|
||||
|
||||
client.shutdown.assert_called_once()
|
||||
|
||||
def test_close_exception(self, tencent_data_trace):
|
||||
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.__del__()
|
||||
tencent_data_trace.close()
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
|
||||
def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
|
||||
shutdown = AsyncMock()
|
||||
tencent_data_trace.trace_client.shutdown = shutdown
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
tencent_data_trace.close()
|
||||
gc.collect()
|
||||
|
||||
shutdown.assert_called_once()
|
||||
assert not [
|
||||
warning
|
||||
for warning in caught
|
||||
if issubclass(warning.category, RuntimeWarning)
|
||||
and "AsyncMockMixin._execute_mock_call" in str(warning.message)
|
||||
]
|
||||
|
||||
@ -44,7 +44,7 @@ dependencies = [
|
||||
|
||||
# Emerging: newer and fast-moving, use compatible pins
|
||||
"fastopenapi[flask]~=0.7.0",
|
||||
"graphon~=0.1.2",
|
||||
"graphon~=0.2.2",
|
||||
"httpx-sse~=0.4.0",
|
||||
"json-repair~=0.59.2",
|
||||
]
|
||||
@ -173,7 +173,7 @@ dev = [
|
||||
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.60.0",
|
||||
"pyrefly>=0.61.1",
|
||||
"xinference-client>=2.4.0",
|
||||
]
|
||||
|
||||
|
||||
@ -7,8 +7,8 @@ from sqlalchemy import select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_mail import mail
|
||||
from libs.email_i18n import EmailType, get_email_i18n_service
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
@ -33,67 +33,68 @@ def mail_clean_document_notify_task():
|
||||
|
||||
# send document clean notify mail
|
||||
try:
|
||||
dataset_auto_disable_logs = db.session.scalars(
|
||||
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
|
||||
).all()
|
||||
# group by tenant_id
|
||||
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
|
||||
for dataset_auto_disable_log in dataset_auto_disable_logs:
|
||||
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
|
||||
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
|
||||
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
|
||||
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
|
||||
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
plan = features.billing.subscription.plan
|
||||
if plan != CloudPlan.SANDBOX:
|
||||
knowledge_details = []
|
||||
# check tenant
|
||||
tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
# check current owner
|
||||
current_owner_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
)
|
||||
if not current_owner_join:
|
||||
continue
|
||||
account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
|
||||
if not account:
|
||||
continue
|
||||
with session_factory.create_session() as session:
|
||||
dataset_auto_disable_logs = session.scalars(
|
||||
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False))
|
||||
).all()
|
||||
# group by tenant_id
|
||||
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
|
||||
for dataset_auto_disable_log in dataset_auto_disable_logs:
|
||||
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
|
||||
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
|
||||
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
|
||||
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
|
||||
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
plan = features.billing.subscription.plan
|
||||
if plan != CloudPlan.SANDBOX:
|
||||
knowledge_details = []
|
||||
# check tenant
|
||||
tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id))
|
||||
if not tenant:
|
||||
continue
|
||||
# check current owner
|
||||
current_owner_join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
)
|
||||
if not current_owner_join:
|
||||
continue
|
||||
account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id))
|
||||
if not account:
|
||||
continue
|
||||
|
||||
dataset_auto_dataset_map = {} # type: ignore
|
||||
dataset_auto_dataset_map = {} # type: ignore
|
||||
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
|
||||
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
|
||||
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
|
||||
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
|
||||
dataset_auto_disable_log.document_id
|
||||
)
|
||||
|
||||
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id))
|
||||
if dataset:
|
||||
document_count = len(document_ids)
|
||||
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
||||
if knowledge_details:
|
||||
email_service = get_email_i18n_service()
|
||||
email_service.send_email(
|
||||
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
|
||||
language_code="en-US",
|
||||
to=account.email,
|
||||
template_context={
|
||||
"userName": account.email,
|
||||
"knowledge_details": knowledge_details,
|
||||
"url": url,
|
||||
},
|
||||
)
|
||||
|
||||
# update notified to True
|
||||
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
|
||||
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
|
||||
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
|
||||
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
|
||||
dataset_auto_disable_log.document_id
|
||||
)
|
||||
|
||||
for dataset_id, document_ids in dataset_auto_dataset_map.items():
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id))
|
||||
if dataset:
|
||||
document_count = len(document_ids)
|
||||
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
|
||||
if knowledge_details:
|
||||
email_service = get_email_i18n_service()
|
||||
email_service.send_email(
|
||||
email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
|
||||
language_code="en-US",
|
||||
to=account.email,
|
||||
template_context={
|
||||
"userName": account.email,
|
||||
"knowledge_details": knowledge_details,
|
||||
"url": url,
|
||||
},
|
||||
)
|
||||
|
||||
# update notified to True
|
||||
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
|
||||
dataset_auto_disable_log.notified = True
|
||||
db.session.commit()
|
||||
dataset_auto_disable_log.notified = True
|
||||
session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
|
||||
@ -17,6 +17,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.dsl_version import CURRENT_APP_DSL_VERSION
|
||||
from core.helper import ssrf_proxy
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.trigger.constants import (
|
||||
@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
|
||||
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
|
||||
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.6.0"
|
||||
CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION
|
||||
|
||||
|
||||
class Import(BaseModel):
|
||||
|
||||
@ -16,7 +16,7 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_was_created, app_was_deleted, app_was_updated
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
@ -30,7 +30,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
|
||||
@ -3,6 +3,7 @@ from enum import StrEnum
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from configs import dify_config
|
||||
from constants.dsl_version import CURRENT_APP_DSL_VERSION
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from enums.hosted_provider import HostedTrialProvider
|
||||
from services.billing_service import BillingService
|
||||
@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel):
|
||||
|
||||
|
||||
class SystemFeatureModel(BaseModel):
|
||||
app_dsl_version: str = ""
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ""
|
||||
enable_marketplace: bool = False
|
||||
@ -225,6 +227,7 @@ class FeatureService:
|
||||
@classmethod
|
||||
def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel:
|
||||
system_features = SystemFeatureModel()
|
||||
system_features.app_dsl_version = CURRENT_APP_DSL_VERSION
|
||||
|
||||
cls._fulfill_system_params_from_env(system_features)
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@ -476,7 +476,7 @@ class RagPipelineService:
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
node_type_enum = NodeType(node_type)
|
||||
node_type_enum: NodeType = node_type
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
|
||||
# return default block config
|
||||
|
||||
@ -169,7 +169,7 @@ class VariableTruncator(BaseTruncator):
|
||||
return TruncationResult(StringSegment(value=fallback_result.value), True)
|
||||
|
||||
# Apply final fallback - convert to JSON string and truncate
|
||||
json_str = dumps_with_segments(result.value, ensure_ascii=False)
|
||||
json_str = dumps_with_segments(result.value)
|
||||
if len(json_str) > self._max_size_bytes:
|
||||
json_str = json_str[: self._max_size_bytes] + "..."
|
||||
return TruncationResult(result=StringSegment(value=json_str), truncated=True)
|
||||
|
||||
@ -146,7 +146,7 @@ class DraftVarLoader(VariableLoader):
|
||||
variable = segment_to_variable(
|
||||
segment=segment,
|
||||
selector=draft_var.get_selector(),
|
||||
id=draft_var.id,
|
||||
variable_id=draft_var.id,
|
||||
name=draft_var.name,
|
||||
description=draft_var.description,
|
||||
)
|
||||
@ -180,7 +180,7 @@ class DraftVarLoader(VariableLoader):
|
||||
variable = segment_to_variable(
|
||||
segment=segment,
|
||||
selector=draft_var.get_selector(),
|
||||
id=draft_var.id,
|
||||
variable_id=draft_var.id,
|
||||
name=draft_var.name,
|
||||
description=draft_var.description,
|
||||
)
|
||||
@ -191,7 +191,7 @@ class DraftVarLoader(VariableLoader):
|
||||
variable = segment_to_variable(
|
||||
segment=segment,
|
||||
selector=draft_var.get_selector(),
|
||||
id=draft_var.id,
|
||||
variable_id=draft_var.id,
|
||||
name=draft_var.name,
|
||||
description=draft_var.description,
|
||||
)
|
||||
@ -1075,7 +1075,7 @@ class DraftVariableSaver:
|
||||
filename = f"{self._generate_filename(name)}.txt"
|
||||
else:
|
||||
# For other types, store as JSON
|
||||
original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False)
|
||||
original_content_serialized = dumps_with_segments(value_seg.value)
|
||||
content_type = "application/json"
|
||||
filename = f"{self._generate_filename(name)}.json"
|
||||
|
||||
|
||||
@ -18,9 +18,9 @@ from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly,
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES, is_trigger_node_type
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
DeliveryChannelConfig,
|
||||
normalize_human_input_node_data_for_graph,
|
||||
adapt_human_input_node_data_for_graph,
|
||||
parse_human_input_delivery_methods,
|
||||
)
|
||||
from core.workflow.node_factory import (
|
||||
@ -112,7 +112,8 @@ class WorkflowService:
|
||||
def __init__(self, session_maker: sessionmaker | None = None):
|
||||
"""Initialize WorkflowService with repository dependencies."""
|
||||
if session_maker is None:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
session_maker = sessionmaker(
|
||||
bind=db.engine, expire_on_commit=False)
|
||||
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
session_maker
|
||||
)
|
||||
@ -216,7 +217,8 @@ class WorkflowService:
|
||||
if not app_ids:
|
||||
return set()
|
||||
|
||||
stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id)
|
||||
stmt = select(App.id).where(App.id.in_(
|
||||
app_ids), App.tenant_id == tenant_id)
|
||||
return {str(app_id) for app_id in db.session.scalars(stmt).all()}
|
||||
|
||||
def get_all_published_workflow(
|
||||
@ -300,7 +302,8 @@ class WorkflowService:
|
||||
)
|
||||
)
|
||||
|
||||
stmt = stmt.order_by(Workflow.created_at.desc()).limit(limit + 1).offset((page - 1) * limit)
|
||||
stmt = stmt.order_by(Workflow.created_at.desc()).limit(
|
||||
limit + 1).offset((page - 1) * limit)
|
||||
|
||||
workflows = session.scalars(stmt).all()
|
||||
|
||||
@ -332,7 +335,8 @@ class WorkflowService:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(app_model=app_model, features=features)
|
||||
self.validate_features_structure(
|
||||
app_model=app_model, features=features)
|
||||
|
||||
# validate graph structure
|
||||
self.validate_graph_structure(graph=graph)
|
||||
@ -364,7 +368,8 @@ class WorkflowService:
|
||||
db.session.commit()
|
||||
|
||||
# trigger app workflow events
|
||||
app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=workflow)
|
||||
app_draft_workflow_was_synced.send(
|
||||
app_model, synced_draft_workflow=workflow)
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
@ -432,7 +437,8 @@ class WorkflowService:
|
||||
raise ValueError("No draft workflow found.")
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(app_model=app_model, features=features)
|
||||
self.validate_features_structure(
|
||||
app_model=app_model, features=features)
|
||||
|
||||
workflow.features = json.dumps(features)
|
||||
workflow.updated_by = account.id
|
||||
@ -453,11 +459,13 @@ class WorkflowService:
|
||||
Secret environment variables are copied server-side from the selected
|
||||
published workflow so the normal draft sync flow stays stateless.
|
||||
"""
|
||||
source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id)
|
||||
source_workflow = self.get_published_workflow_by_id(
|
||||
app_model=app_model, workflow_id=workflow_id)
|
||||
if not source_workflow:
|
||||
raise WorkflowNotFoundError("Workflow not found.")
|
||||
|
||||
self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict)
|
||||
self.validate_features_structure(
|
||||
app_model=app_model, features=source_workflow.normalized_features_dict)
|
||||
self.validate_graph_structure(graph=source_workflow.graph_dict)
|
||||
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
@ -474,7 +482,8 @@ class WorkflowService:
|
||||
db.session.add(draft_workflow)
|
||||
|
||||
db.session.commit()
|
||||
app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow)
|
||||
app_draft_workflow_was_synced.send(
|
||||
app_model, synced_draft_workflow=draft_workflow)
|
||||
|
||||
return draft_workflow
|
||||
|
||||
@ -518,7 +527,8 @@ class WorkflowService:
|
||||
and is_trigger_node_type(node_type_str)
|
||||
)
|
||||
if trigger_node_count > 2:
|
||||
raise TriggerNodeLimitExceededError(count=trigger_node_count, limit=2)
|
||||
raise TriggerNodeLimitExceededError(
|
||||
count=trigger_node_count, limit=2)
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow.new(
|
||||
@ -540,7 +550,8 @@ class WorkflowService:
|
||||
session.add(workflow)
|
||||
|
||||
# trigger app workflow events
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
app_published_workflow_was_updated.send(
|
||||
app_model, published_workflow=workflow)
|
||||
|
||||
# return new workflow
|
||||
return workflow
|
||||
@ -597,7 +608,8 @@ class WorkflowService:
|
||||
session.add(workflow)
|
||||
|
||||
# trigger app workflow events
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
app_published_workflow_was_updated.send(
|
||||
app_model, published_workflow=workflow)
|
||||
|
||||
return workflow
|
||||
|
||||
@ -615,7 +627,8 @@ class WorkflowService:
|
||||
This endpoint only supports conversion between standard workflow and evaluation workflow.
|
||||
"""
|
||||
if target_type not in {WorkflowType.WORKFLOW, WorkflowType.EVALUATION}:
|
||||
raise ValueError("target_type must be either 'workflow' or 'evaluation'")
|
||||
raise ValueError(
|
||||
"target_type must be either 'workflow' or 'evaluation'")
|
||||
|
||||
if not app_model.workflow_id:
|
||||
raise WorkflowNotFoundError("Published workflow not found")
|
||||
@ -630,7 +643,8 @@ class WorkflowService:
|
||||
raise WorkflowNotFoundError("Published workflow not found")
|
||||
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise IsDraftWorkflowError("Current effective workflow cannot be a draft version.")
|
||||
raise IsDraftWorkflowError(
|
||||
"Current effective workflow cannot be a draft version.")
|
||||
|
||||
if workflow.type == target_type:
|
||||
return workflow
|
||||
@ -642,7 +656,8 @@ class WorkflowService:
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = naive_utc_now()
|
||||
|
||||
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
|
||||
app_published_workflow_was_updated.send(
|
||||
app_model, published_workflow=workflow)
|
||||
|
||||
return workflow
|
||||
|
||||
@ -660,7 +675,8 @@ class WorkflowService:
|
||||
if not disallowed_nodes:
|
||||
return
|
||||
|
||||
formatted_nodes = ", ".join(f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes)
|
||||
formatted_nodes = ", ".join(
|
||||
f"{node_id}:{node_type}" for node_id, node_type in disallowed_nodes)
|
||||
raise ValueError(
|
||||
"Evaluation workflow cannot contain trigger or human-input nodes. "
|
||||
f"Found disallowed nodes: {formatted_nodes}"
|
||||
@ -698,12 +714,14 @@ class WorkflowService:
|
||||
)
|
||||
else:
|
||||
# Check default workspace credential for this provider
|
||||
self._check_default_tool_credential(workflow.tenant_id, provider)
|
||||
self._check_default_tool_credential(
|
||||
workflow.tenant_id, provider)
|
||||
|
||||
elif node_type == "agent":
|
||||
agent_params = node_data.get("agent_parameters", {})
|
||||
|
||||
model_config = agent_params.get("model", {}).get("value", {})
|
||||
model_config = agent_params.get(
|
||||
"model", {}).get("value", {})
|
||||
if model_config.get("provider") and model_config.get("model"):
|
||||
self._validate_llm_model_config(
|
||||
workflow.tenant_id, model_config["provider"], model_config["model"]
|
||||
@ -711,7 +729,8 @@ class WorkflowService:
|
||||
|
||||
# Validate load balancing credentials for agent model if load balancing is enabled
|
||||
agent_model_node_data = {"model": model_config}
|
||||
self._validate_load_balancing_credentials(workflow, agent_model_node_data, node_id)
|
||||
self._validate_load_balancing_credentials(
|
||||
workflow, agent_model_node_data, node_id)
|
||||
|
||||
# Validate agent tools
|
||||
tools = agent_params.get("tools", {}).get("value", [])
|
||||
@ -723,9 +742,11 @@ class WorkflowService:
|
||||
if credential_id:
|
||||
from core.helper.credential_utils import check_credential_policy_compliance
|
||||
|
||||
check_credential_policy_compliance(credential_id, provider, PluginCredentialType.TOOL)
|
||||
check_credential_policy_compliance(
|
||||
credential_id, provider, PluginCredentialType.TOOL)
|
||||
else:
|
||||
self._check_default_tool_credential(workflow.tenant_id, provider)
|
||||
self._check_default_tool_credential(
|
||||
workflow.tenant_id, provider)
|
||||
|
||||
elif node_type in ["llm", "knowledge_retrieval", "parameter_extractor", "question_classifier"]:
|
||||
model_config = node_data.get("model", {})
|
||||
@ -734,11 +755,14 @@ class WorkflowService:
|
||||
|
||||
if provider and model_name:
|
||||
# Validate that the provider+model combination can fetch valid credentials
|
||||
self._validate_llm_model_config(workflow.tenant_id, provider, model_name)
|
||||
self._validate_llm_model_config(
|
||||
workflow.tenant_id, provider, model_name)
|
||||
# Validate load balancing credentials if load balancing is enabled
|
||||
self._validate_load_balancing_credentials(workflow, node_data, node_id)
|
||||
self._validate_load_balancing_credentials(
|
||||
workflow, node_data, node_id)
|
||||
else:
|
||||
raise ValueError(f"Node {node_id} ({node_type}): Missing provider or model configuration")
|
||||
raise ValueError(
|
||||
f"Node {node_id} ({node_type}): Missing provider or model configuration")
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
@ -780,8 +804,10 @@ class WorkflowService:
|
||||
# If it fails, an exception will be raised
|
||||
|
||||
# Additionally, check the model status to ensure it's ACTIVE
|
||||
provider_configurations = assembly.provider_manager.get_configurations(tenant_id)
|
||||
models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM)
|
||||
provider_configurations = assembly.provider_manager.get_configurations(
|
||||
tenant_id)
|
||||
models = provider_configurations.get_models(
|
||||
provider=provider, model_type=ModelType.LLM)
|
||||
|
||||
target_model = None
|
||||
for model in models:
|
||||
@ -792,7 +818,8 @@ class WorkflowService:
|
||||
if target_model:
|
||||
target_model.raise_for_status()
|
||||
else:
|
||||
raise ValueError(f"Model {model_name} not found for provider {provider}")
|
||||
raise ValueError(
|
||||
f"Model {model_name} not found for provider {provider}")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
@ -840,7 +867,8 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Failed to validate default credential for tool provider {provider}: {str(e)}")
|
||||
|
||||
def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None:
|
||||
"""
|
||||
@ -862,7 +890,8 @@ class WorkflowService:
|
||||
# Check if this model has load balancing enabled
|
||||
if self._is_load_balancing_enabled(workflow.tenant_id, provider, model_name):
|
||||
# Get all load balancing configurations for this model
|
||||
load_balancing_configs = self._get_load_balancing_configs(workflow.tenant_id, provider, model_name)
|
||||
load_balancing_configs = self._get_load_balancing_configs(
|
||||
workflow.tenant_id, provider, model_name)
|
||||
# Validate each load balancing configuration
|
||||
try:
|
||||
for config in load_balancing_configs:
|
||||
@ -873,7 +902,8 @@ class WorkflowService:
|
||||
config["credential_id"], provider, PluginCredentialType.MODEL
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Invalid load balancing credentials for {provider}/{model_name}: {str(e)}")
|
||||
|
||||
def _is_load_balancing_enabled(self, tenant_id: str, provider: str, model_name: str) -> bool:
|
||||
"""
|
||||
@ -888,8 +918,10 @@ class WorkflowService:
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
# Get provider configurations
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=tenant_id)
|
||||
provider_configurations = provider_manager.get_configurations(tenant_id)
|
||||
provider_manager = create_plugin_provider_manager(
|
||||
tenant_id=tenant_id)
|
||||
provider_configurations = provider_manager.get_configurations(
|
||||
tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
|
||||
if not provider_configuration:
|
||||
@ -930,7 +962,8 @@ class WorkflowService:
|
||||
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
|
||||
)
|
||||
all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs)
|
||||
all_configs = cast(list[dict[str, Any]], configs) + \
|
||||
cast(list[dict[str, Any]], custom_configs)
|
||||
|
||||
return [config for config in all_configs if config.get("credential_id")]
|
||||
|
||||
@ -975,7 +1008,7 @@ class WorkflowService:
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
node_type_enum = NodeType(node_type)
|
||||
node_type_enum: NodeType = node_type
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
|
||||
# return default block config
|
||||
@ -994,7 +1027,8 @@ class WorkflowService:
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
)
|
||||
default_config = node_class.get_default_config(filters=resolved_filters or None)
|
||||
default_config = node_class.get_default_config(
|
||||
filters=resolved_filters or None)
|
||||
if not default_config:
|
||||
return {}
|
||||
|
||||
@ -1017,7 +1051,8 @@ class WorkflowService:
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=account.id)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(
|
||||
draft_workflow, user_id=account.id)
|
||||
|
||||
node_config = draft_workflow.get_node_config_by_id(node_id)
|
||||
node_type = Workflow.get_node_type_from_node_config(node_config)
|
||||
@ -1031,7 +1066,8 @@ class WorkflowService:
|
||||
workflow=draft_workflow,
|
||||
)
|
||||
if node_type == BuiltinNodeTypes.START:
|
||||
start_data = StartNodeData.model_validate(node_data, from_attributes=True)
|
||||
start_data = StartNodeData.model_validate(
|
||||
node_data, from_attributes=True)
|
||||
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
|
||||
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
|
||||
)
|
||||
@ -1066,7 +1102,8 @@ class WorkflowService:
|
||||
user_id=account.id,
|
||||
)
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(
|
||||
node_config)
|
||||
if enclosing_node_type_and_id:
|
||||
_, enclosing_node_id = enclosing_node_type_and_id
|
||||
else:
|
||||
@ -1101,12 +1138,15 @@ class WorkflowService:
|
||||
)
|
||||
repository.save(node_execution)
|
||||
|
||||
workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
|
||||
workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(
|
||||
node_execution.id)
|
||||
if workflow_node_execution is None:
|
||||
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
|
||||
raise ValueError(
|
||||
f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
||||
outputs = workflow_node_execution.load_full_outputs(
|
||||
session, storage)
|
||||
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
@ -1118,7 +1158,8 @@ class WorkflowService:
|
||||
node_execution_id=node_execution.id,
|
||||
user=account,
|
||||
)
|
||||
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
|
||||
draft_var_saver.save(
|
||||
process_data=node_execution.process_data, outputs=outputs)
|
||||
|
||||
enqueue_draft_node_execution_trace(
|
||||
execution=workflow_node_execution,
|
||||
@ -1245,7 +1286,8 @@ class WorkflowService:
|
||||
rendered_content, outputs, node_data.outputs_field_names()
|
||||
)
|
||||
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
|
||||
enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(
|
||||
node_config)
|
||||
enclosing_node_id = enclosing_node_type_and_id[1] if enclosing_node_type_and_id else None
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
@ -1280,7 +1322,7 @@ class WorkflowService:
|
||||
raise ValueError("Node type must be human-input.")
|
||||
|
||||
node_data = HumanInputNodeData.model_validate(
|
||||
normalize_human_input_node_data_for_graph(node_config["data"]),
|
||||
adapt_human_input_node_data_for_graph(node_config["data"]),
|
||||
from_attributes=True,
|
||||
)
|
||||
delivery_method = self._resolve_human_input_delivery_method(
|
||||
@ -1332,7 +1374,8 @@ class WorkflowService:
|
||||
try:
|
||||
test_service.send_test(context=context, method=delivery_method)
|
||||
except DeliveryTestUnsupportedError as exc:
|
||||
raise ValueError("Delivery method does not support test send.") from exc
|
||||
raise ValueError(
|
||||
"Delivery method does not support test send.") from exc
|
||||
except DeliveryTestError as exc:
|
||||
raise ValueError(str(exc)) from exc
|
||||
|
||||
@ -1357,7 +1400,8 @@ class WorkflowService:
|
||||
rendered_content: str,
|
||||
resolved_default_values: Mapping[str, Any],
|
||||
) -> tuple[str, list[DeliveryTestEmailRecipient]]:
|
||||
repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
repo = HumanInputFormRepositoryImpl(
|
||||
tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
params = FormCreateParams(
|
||||
workflow_execution_id=None,
|
||||
node_id=node_id,
|
||||
@ -1377,7 +1421,8 @@ class WorkflowService:
|
||||
|
||||
with Session(bind=db.engine) as session:
|
||||
recipients = session.scalars(
|
||||
select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_id)
|
||||
select(HumanInputFormRecipient).where(
|
||||
HumanInputFormRecipient.form_id == form_id)
|
||||
).all()
|
||||
recipients_data: list[DeliveryTestEmailRecipient] = []
|
||||
for recipient in recipients:
|
||||
@ -1388,11 +1433,13 @@ class WorkflowService:
|
||||
try:
|
||||
payload = json.loads(recipient.recipient_payload)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.exception("Failed to parse human input recipient payload for delivery test.")
|
||||
logger.exception(
|
||||
"Failed to parse human input recipient payload for delivery test.")
|
||||
continue
|
||||
email = payload.get("email")
|
||||
if email:
|
||||
recipients_data.append(DeliveryTestEmailRecipient(email=email, form_token=recipient.access_token))
|
||||
recipients_data.append(DeliveryTestEmailRecipient(
|
||||
email=email, form_token=recipient.access_token))
|
||||
return recipients_data
|
||||
|
||||
def _build_human_input_node(
|
||||
@ -1421,9 +1468,11 @@ class WorkflowService:
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
node_data = HumanInputNode.validate_node_data(
|
||||
adapt_human_input_node_data_for_graph(node_config["data"]))
|
||||
node = HumanInputNode(
|
||||
id=node_config["id"],
|
||||
config=node_config,
|
||||
node_id=node_config["id"],
|
||||
config=node_data,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=DifyHumanInputNodeRuntime(run_context),
|
||||
@ -1441,7 +1490,8 @@ class WorkflowService:
|
||||
) -> VariablePool:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(
|
||||
workflow, user_id=user_id)
|
||||
|
||||
variable_pool = VariablePool()
|
||||
add_variables_to_pool(
|
||||
@ -1519,7 +1569,8 @@ class WorkflowService:
|
||||
Returns:
|
||||
WorkflowNodeExecution: The execution result
|
||||
"""
|
||||
node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn)
|
||||
node, node_run_result, run_succeeded, error = self._execute_node_safely(
|
||||
invoke_node_fn)
|
||||
|
||||
# Create base node execution
|
||||
node_execution = WorkflowNodeExecution(
|
||||
@ -1535,7 +1586,8 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
# Populate execution result data
|
||||
self._populate_execution_result(node_execution, node_run_result, run_succeeded, error)
|
||||
self._populate_execution_result(
|
||||
node_execution, node_run_result, run_succeeded, error)
|
||||
|
||||
return node_execution
|
||||
|
||||
@ -1564,7 +1616,8 @@ class WorkflowService:
|
||||
|
||||
# Apply error strategy if node failed
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.error_strategy:
|
||||
node_run_result = self._apply_error_strategy(node, node_run_result)
|
||||
node_run_result = self._apply_error_strategy(
|
||||
node, node_run_result)
|
||||
|
||||
run_succeeded = node_run_result.status in (
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -1595,7 +1648,8 @@ class WorkflowService:
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
error=node_run_result.error,
|
||||
inputs=node_run_result.inputs,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy},
|
||||
outputs=error_outputs,
|
||||
)
|
||||
|
||||
@ -1609,10 +1663,12 @@ class WorkflowService:
|
||||
"""Populate node execution with result data."""
|
||||
if run_succeeded and node_run_result:
|
||||
node_execution.inputs = (
|
||||
WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
|
||||
WorkflowEntry.handle_special_values(
|
||||
node_run_result.inputs) if node_run_result.inputs else None
|
||||
)
|
||||
node_execution.process_data = (
|
||||
WorkflowEntry.handle_special_values(node_run_result.process_data)
|
||||
WorkflowEntry.handle_special_values(
|
||||
node_run_result.process_data)
|
||||
if node_run_result.process_data
|
||||
else None
|
||||
)
|
||||
@ -1641,7 +1697,8 @@ class WorkflowService:
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in {AppMode.CHAT, AppMode.COMPLETION}:
|
||||
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
raise ValueError(
|
||||
f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
|
||||
# convert to workflow
|
||||
new_app: App = workflow_converter.convert_to_workflow(
|
||||
@ -1678,7 +1735,8 @@ class WorkflowService:
|
||||
# start node and trigger node cannot coexist
|
||||
if BuiltinNodeTypes.START in node_types:
|
||||
if any(is_trigger_node_type(nt) for nt in node_types):
|
||||
raise ValueError("Start node and trigger nodes cannot coexist in the same workflow")
|
||||
raise ValueError(
|
||||
"Start node and trigger nodes cannot coexist in the same workflow")
|
||||
|
||||
for node in node_configs:
|
||||
node_data = node.get("data", {})
|
||||
@ -1713,7 +1771,8 @@ class WorkflowService:
|
||||
from graphon.nodes.human_input.entities import HumanInputNodeData
|
||||
|
||||
try:
|
||||
HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data))
|
||||
HumanInputNodeData.model_validate(
|
||||
adapt_human_input_node_data_for_graph(node_data))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid HumanInput node data: {str(e)}")
|
||||
|
||||
@ -1730,7 +1789,8 @@ class WorkflowService:
|
||||
:param data: Dictionary containing fields to update
|
||||
:return: Updated workflow or None if not found
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
@ -1759,7 +1819,8 @@ class WorkflowService:
|
||||
:raises: WorkflowInUseError if workflow is in use
|
||||
:raises: DraftWorkflowDeletionError if workflow is a draft version
|
||||
"""
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
|
||||
stmt = select(Workflow).where(Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id)
|
||||
workflow = session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
@ -1767,14 +1828,16 @@ class WorkflowService:
|
||||
|
||||
# Check if workflow is a draft version
|
||||
if workflow.version == Workflow.VERSION_DRAFT:
|
||||
raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")
|
||||
raise DraftWorkflowDeletionError(
|
||||
"Cannot delete draft workflow versions")
|
||||
|
||||
# Check if this workflow is currently referenced by an app
|
||||
app_stmt = select(App).where(App.workflow_id == workflow_id)
|
||||
app = session.scalar(app_stmt)
|
||||
if app:
|
||||
# Cannot delete a workflow that's currently in use by an app
|
||||
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")
|
||||
raise WorkflowInUseError(
|
||||
f"Cannot delete workflow that is currently in use by app '{app.id}'")
|
||||
|
||||
# Don't use workflow.tool_published as it's not accurate for specific workflow versions
|
||||
# Check if there's a tool provider using this specific workflow version
|
||||
@ -1788,7 +1851,8 @@ class WorkflowService:
|
||||
|
||||
if tool_provider:
|
||||
# Cannot delete a workflow that's published as a tool
|
||||
raise WorkflowInUseError("Cannot delete workflow that is published as a tool")
|
||||
raise WorkflowInUseError(
|
||||
"Cannot delete workflow that is published as a tool")
|
||||
|
||||
session.delete(workflow)
|
||||
return True
|
||||
@ -1837,11 +1901,13 @@ def _setup_variable_pool(
|
||||
build_bootstrap_variables(
|
||||
system_variables=system_variable,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=cast(list[Variable], conversation_variables),
|
||||
conversation_variables=cast(
|
||||
list[Variable], conversation_variables),
|
||||
),
|
||||
)
|
||||
if is_start_node_type(node_type):
|
||||
add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs)
|
||||
add_node_inputs_to_pool(
|
||||
variable_pool, node_id=node_id, inputs=user_inputs)
|
||||
|
||||
return variable_pool
|
||||
|
||||
@ -1857,7 +1923,8 @@ def _rebuild_file_for_user_inputs_in_start_node(
|
||||
if variable.variable not in user_inputs:
|
||||
continue
|
||||
value = user_inputs[variable.variable]
|
||||
file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
|
||||
file = _rebuild_single_file(
|
||||
tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
|
||||
inputs_copy[variable.variable] = file
|
||||
return inputs_copy
|
||||
|
||||
@ -1865,15 +1932,18 @@ def _rebuild_file_for_user_inputs_in_start_node(
|
||||
def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]:
|
||||
if variable_entity_type == VariableEntityType.FILE:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"expected dict for file object, got {type(value)}")
|
||||
raise ValueError(
|
||||
f"expected dict for file object, got {type(value)}")
|
||||
return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller)
|
||||
elif variable_entity_type == VariableEntityType.FILE_LIST:
|
||||
if not isinstance(value, list):
|
||||
raise ValueError(f"expected list for file list object, got {type(value)}")
|
||||
raise ValueError(
|
||||
f"expected list for file list object, got {type(value)}")
|
||||
if len(value) == 0:
|
||||
return []
|
||||
if not isinstance(value[0], dict):
|
||||
raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
|
||||
raise ValueError(
|
||||
f"expected dict for first element in the file list, got {type(value)}")
|
||||
return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller)
|
||||
else:
|
||||
raise Exception("unreachable")
|
||||
|
||||
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod
|
||||
from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_mail import mail
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeRunResult, StreamCompletedEvent
|
||||
|
||||
@ -69,19 +70,16 @@ def test_node_integration_minimal_stream(mocker):
|
||||
mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr)
|
||||
|
||||
node = DatasourceNode(
|
||||
id="n",
|
||||
config={
|
||||
"id": "n",
|
||||
"data": {
|
||||
"type": "datasource",
|
||||
"version": "1",
|
||||
"title": "Datasource",
|
||||
"provider_type": "plugin",
|
||||
"provider_name": "p",
|
||||
"plugin_id": "plug",
|
||||
"datasource_name": "ds",
|
||||
},
|
||||
},
|
||||
node_id="n",
|
||||
config=DatasourceNodeData(
|
||||
type="datasource",
|
||||
version="1",
|
||||
title="Datasource",
|
||||
provider_type="plugin",
|
||||
provider_name="p",
|
||||
plugin_id="plug",
|
||||
datasource_name="ds",
|
||||
),
|
||||
graph_init_params=_GP(),
|
||||
graph_runtime_state=_GS(vp),
|
||||
)
|
||||
|
||||
@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.nodes.code.code_node import CodeNode
|
||||
from graphon.nodes.code.entities import CodeNodeData
|
||||
from graphon.nodes.code.limits import CodeNodeLimits
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
@ -64,8 +65,8 @@ def init_code_node(code_config: dict):
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = CodeNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=code_config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=CodeNodeData.model_validate(code_config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
code_executor=node_factory._code_executor,
|
||||
|
||||
@ -14,7 +14,7 @@ from core.workflow.system_variables import build_system_variables
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.file.file_manager import file_manager
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig
|
||||
from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
|
||||
@ -75,8 +75,8 @@ def init_http_node(config: dict):
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
@ -723,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
|
||||
|
||||
node = HttpRequestNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=graph_config["nodes"][1],
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
http_request_config=HTTP_REQUEST_CONFIG,
|
||||
|
||||
@ -11,6 +11,7 @@ from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
from graphon.nodes.llm.entities import LLMNodeData
|
||||
from graphon.nodes.llm.file_saver import LLMFileSaver
|
||||
from graphon.nodes.llm.node import LLMNode
|
||||
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
@ -75,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
llm_file_saver = MagicMock(spec=LLMFileSaver)
|
||||
|
||||
node = LLMNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=LLMNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@ -11,6 +11,7 @@ from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
|
||||
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
|
||||
@ -69,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None):
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node = ParameterExtractorNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ParameterExtractorNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
|
||||
@ -6,6 +6,7 @@ from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from graphon.template_rendering import TemplateRenderError
|
||||
@ -86,8 +87,8 @@ def test_execute_template_transform():
|
||||
assert graph is not None
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=TemplateTransformNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
jinja2_template_renderer=_SimpleJinja2Renderer(),
|
||||
|
||||
@ -11,6 +11,7 @@ from graphon.enums import WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
from graphon.node_events import StreamCompletedEvent
|
||||
from graphon.nodes.protocols import ToolFileManagerProtocol
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from graphon.nodes.tool.tool_node import ToolNode
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from tests.workflow_test_utils import build_test_graph_init_params
|
||||
@ -60,8 +61,8 @@ def init_tool_node(config: dict):
|
||||
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
|
||||
|
||||
node = ToolNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=config,
|
||||
node_id=str(uuid.uuid4()),
|
||||
config=ToolNodeData.model_validate(config["data"]),
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
tool_file_manager_factory=tool_file_manager_factory,
|
||||
|
||||
@ -8,7 +8,7 @@ from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
|
||||
@ -101,8 +101,8 @@ def _build_graph(
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
node_id="start",
|
||||
config=start_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@ -116,8 +116,8 @@ def _build_graph(
|
||||
],
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
id="human",
|
||||
config={"id": "human", "data": human_data.model_dump()},
|
||||
node_id="human",
|
||||
config=human_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
@ -130,8 +130,8 @@ def _build_graph(
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
node_id="end",
|
||||
config=end_data,
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
@ -123,9 +123,9 @@ class TestStorageKeyLoader(unittest.TestCase):
|
||||
file_related_id = related_id
|
||||
|
||||
return File(
|
||||
id=str(uuid4()), # Generate new UUID for File.id
|
||||
file_id=str(uuid4()), # Generate new UUID for File.id
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=transfer_method,
|
||||
related_id=file_related_id,
|
||||
remote_url=remote_url,
|
||||
|
||||
@ -271,7 +271,7 @@ def _create_recipient(
|
||||
|
||||
|
||||
def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery:
|
||||
from core.workflow.human_input_compat import DeliveryMethodType
|
||||
from core.workflow.human_input_adapter import DeliveryMethodType
|
||||
from models.human_input import ConsoleDeliveryPayload
|
||||
|
||||
delivery = HumanInputDelivery(
|
||||
|
||||
@ -4,7 +4,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@ -8,7 +8,7 @@ import pytest
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@ -10,7 +10,7 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
|
||||
from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@ -31,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
file_list = [
|
||||
File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="http://u",
|
||||
)
|
||||
|
||||
@ -314,8 +314,8 @@ def test_workflow_file_variable_with_signed_url():
|
||||
|
||||
# Create a File object with LOCAL_FILE transfer method (which generates signed URLs)
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
type=FileType.IMAGE,
|
||||
file_id="test_file_id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="test_upload_file_id",
|
||||
filename="test.jpg",
|
||||
@ -370,8 +370,8 @@ def test_workflow_file_variable_remote_url():
|
||||
|
||||
# Create a File object with REMOTE_URL transfer method
|
||||
test_file = File(
|
||||
id="test_file_id",
|
||||
type=FileType.IMAGE,
|
||||
file_id="test_file_id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/test.jpg",
|
||||
filename="test.jpg",
|
||||
|
||||
@ -37,6 +37,8 @@ from controllers.service_api.app.conversation import (
|
||||
ConversationVariableUpdatePayload,
|
||||
)
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from graphon.variables import StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
@ -284,6 +286,32 @@ class TestConversationVariableResponseModels:
|
||||
assert response.created_at == int(created_at.timestamp())
|
||||
assert response.updated_at == int(created_at.timestamp())
|
||||
|
||||
def test_variable_response_normalizes_string_value_type_alias(self):
|
||||
response = ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": SegmentType.INTEGER.value,
|
||||
}
|
||||
)
|
||||
|
||||
assert response.value_type == "number"
|
||||
|
||||
def test_variable_response_normalizes_callable_exposed_type(self):
|
||||
response = ConversationVariableResponse.model_validate(
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "foo",
|
||||
"value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
|
||||
}
|
||||
)
|
||||
|
||||
assert response.value_type == "string"
|
||||
|
||||
def test_serialize_value_type_supports_segments_and_mappings(self):
|
||||
assert serialize_value_type(StringSegment(value="hello")) == "string"
|
||||
assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
|
||||
|
||||
def test_variable_pagination_response(self):
|
||||
response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
|
||||
{
|
||||
|
||||
@ -11,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||
def create_test_file(self, file_id: str = "test_file_1") -> File:
|
||||
"""Create a test File object"""
|
||||
return File(
|
||||
id=file_id,
|
||||
type=FileType.DOCUMENT,
|
||||
file_id=file_id,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="related_123",
|
||||
filename=f"{file_id}.txt",
|
||||
|
||||
@ -7,11 +7,11 @@ import graphon.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow import node_factory as node_factory_module
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from graphon.entities.pause_reason import SchedulingPause
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
|
||||
from graphon.graph import Graph
|
||||
@ -55,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def init_node_data(self, data):
|
||||
self._node_data = _StubToolNodeData.model_validate(data)
|
||||
def __init__(
|
||||
self,
|
||||
node_id: str,
|
||||
config: _StubToolNodeData,
|
||||
*,
|
||||
graph_init_params,
|
||||
graph_runtime_state,
|
||||
**_kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
@ -89,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]):
|
||||
|
||||
|
||||
def _patch_tool_node(mocker):
|
||||
original_create_node = DifyNodeFactory.create_node
|
||||
original_resolve_node_class = node_factory_module.resolve_workflow_node_class
|
||||
|
||||
def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node:
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_data = typed_node_config["data"]
|
||||
if node_data.type == BuiltinNodeTypes.TOOL:
|
||||
return _StubToolNode(
|
||||
id=str(typed_node_config["id"]),
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
return original_create_node(self, typed_node_config)
|
||||
def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
if node_type == BuiltinNodeTypes.TOOL:
|
||||
return _StubToolNode
|
||||
return original_resolve_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
|
||||
mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class)
|
||||
|
||||
|
||||
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
|
||||
|
||||
@ -26,8 +26,8 @@ def _build_file(
|
||||
extension: str | None = None,
|
||||
) -> File:
|
||||
return File(
|
||||
id="file-id",
|
||||
type=FileType.IMAGE,
|
||||
file_id="file-id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
reference=reference,
|
||||
remote_url=remote_url,
|
||||
@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M
|
||||
|
||||
assert runtime.multimodal_send_format == "url"
|
||||
|
||||
with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
|
||||
with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get:
|
||||
assert runtime.http_get("http://example", follow_redirects=False) == "response"
|
||||
mock_get.assert_called_once_with("http://example", follow_redirects=False)
|
||||
|
||||
|
||||
@ -8,8 +8,8 @@ from graphon.enums import BuiltinNodeTypes
|
||||
|
||||
|
||||
class DummyNode:
|
||||
def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs):
|
||||
self.id = id
|
||||
def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs):
|
||||
self.id = node_id
|
||||
self.config = config
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
|
||||
@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
|
||||
mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
|
||||
built = File(
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tool_file_1",
|
||||
extension=".png",
|
||||
@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
file_in = File(
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tf",
|
||||
extension=".pdf",
|
||||
|
||||
@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None:
|
||||
|
||||
# Assert
|
||||
assert simple_provider.provider == "openai"
|
||||
assert simple_provider.label.en_US == "OpenAI"
|
||||
assert simple_provider.label.en_us == "OpenAI"
|
||||
assert simple_provider.supported_model_types == [ModelType.LLM]
|
||||
|
||||
|
||||
|
||||
@ -3,9 +3,9 @@ from graphon.file import File, FileTransferMethod, FileType
|
||||
|
||||
def test_file():
|
||||
file = File(
|
||||
id="test-file",
|
||||
file_id="test-file",
|
||||
tenant_id="test-tenant-id",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="test-related-id",
|
||||
filename="image.png",
|
||||
@ -25,27 +25,21 @@ def test_file():
|
||||
assert file.size == 67
|
||||
|
||||
|
||||
def test_file_model_validate_accepts_legacy_tenant_id():
|
||||
data = {
|
||||
"id": "test-file",
|
||||
"tenant_id": "test-tenant-id",
|
||||
"type": "image",
|
||||
"transfer_method": "tool_file",
|
||||
"related_id": "test-related-id",
|
||||
"filename": "image.png",
|
||||
"extension": ".png",
|
||||
"mime_type": "image/png",
|
||||
"size": 67,
|
||||
"storage_key": "test-storage-key",
|
||||
"url": "https://example.com/image.png",
|
||||
# Extra legacy fields
|
||||
"tool_file_id": "tool-file-123",
|
||||
"upload_file_id": "upload-file-456",
|
||||
"datasource_file_id": "datasource-file-789",
|
||||
}
|
||||
def test_file_constructor_accepts_legacy_tenant_id():
|
||||
file = File(
|
||||
file_id="test-file",
|
||||
tenant_id="test-tenant-id",
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
tool_file_id="tool-file-123",
|
||||
filename="image.png",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=67,
|
||||
storage_key="test-storage-key",
|
||||
url="https://example.com/image.png",
|
||||
)
|
||||
|
||||
file = File.model_validate(data)
|
||||
|
||||
assert file.related_id == "test-related-id"
|
||||
assert file.related_id == "tool-file-123"
|
||||
assert file.storage_key == "test-storage-key"
|
||||
assert "tenant_id" not in file.model_dump()
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.helper.ssrf_proxy import (
|
||||
SSRF_DEFAULT_MAX_RETRIES,
|
||||
SSRFProxy,
|
||||
_get_user_provided_host_header,
|
||||
_to_graphon_http_response,
|
||||
graphon_ssrf_proxy,
|
||||
make_request,
|
||||
max_retries_exceeded_error,
|
||||
request_error,
|
||||
)
|
||||
|
||||
|
||||
@ -174,3 +180,56 @@ class TestFollowRedirectsParameter:
|
||||
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
|
||||
|
||||
def test_to_graphon_http_response_preserves_httpx_response_fields() -> None:
|
||||
response = httpx.Response(
|
||||
201,
|
||||
headers={"X-Test": "1"},
|
||||
content=b"payload",
|
||||
request=httpx.Request("GET", "https://example.com/resource"),
|
||||
)
|
||||
|
||||
wrapped = _to_graphon_http_response(response)
|
||||
|
||||
assert wrapped.status_code == 201
|
||||
assert wrapped.headers == {"x-test": "1", "content-length": "7"}
|
||||
assert wrapped.content == b"payload"
|
||||
assert wrapped.url == "https://example.com/resource"
|
||||
assert wrapped.reason_phrase == "Created"
|
||||
assert wrapped.text == "payload"
|
||||
|
||||
|
||||
def test_ssrf_proxy_exposes_expected_error_types() -> None:
|
||||
proxy = SSRFProxy()
|
||||
|
||||
assert proxy.max_retries_exceeded_error is max_retries_exceeded_error
|
||||
assert proxy.request_error is request_error
|
||||
assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error
|
||||
assert graphon_ssrf_proxy.request_error is request_error
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"])
|
||||
def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None:
|
||||
response = httpx.Response(
|
||||
200,
|
||||
headers={"X-Test": "1"},
|
||||
content=b"ok",
|
||||
request=httpx.Request("GET", "https://example.com/resource"),
|
||||
)
|
||||
|
||||
with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method:
|
||||
wrapped = getattr(graphon_ssrf_proxy, method_name)(
|
||||
"https://example.com/resource",
|
||||
max_retries=3,
|
||||
headers={"X-Test": "1"},
|
||||
)
|
||||
|
||||
mock_method.assert_called_once_with(
|
||||
url="https://example.com/resource",
|
||||
max_retries=3,
|
||||
headers={"X-Test": "1"},
|
||||
)
|
||||
assert wrapped.status_code == 200
|
||||
assert wrapped.url == "https://example.com/resource"
|
||||
assert wrapped.content == b"ok"
|
||||
|
||||
@ -13,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import (
|
||||
ProviderCredentialSchema,
|
||||
ProviderEntity,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
|
||||
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
|
||||
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
|
||||
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
|
||||
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
|
||||
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ class TestPluginModelRuntime:
|
||||
assert len(providers) == 1
|
||||
assert providers[0].provider == "langgenius/openai/openai"
|
||||
assert providers[0].provider_name == "openai"
|
||||
assert providers[0].label.en_US == "OpenAI"
|
||||
assert providers[0].label.en_us == "OpenAI"
|
||||
client.fetch_model_providers.assert_called_once_with("tenant")
|
||||
|
||||
def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None:
|
||||
|
||||
@ -466,7 +466,7 @@ class TestConverter:
|
||||
def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self):
|
||||
file_param = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/file.png",
|
||||
storage_key="",
|
||||
@ -499,14 +499,14 @@ class TestConverter:
|
||||
def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self):
|
||||
file_one = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/a.txt",
|
||||
storage_key="",
|
||||
)
|
||||
file_two = File(
|
||||
tenant_id="tenant-1",
|
||||
type=FileType.DOCUMENT,
|
||||
file_type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/b.txt",
|
||||
storage_key="",
|
||||
|
||||
@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
|
||||
files = [
|
||||
File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="",
|
||||
@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files():
|
||||
completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2")
|
||||
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query():
|
||||
memory = MagicMock(spec=TokenBufferMemory)
|
||||
prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)]
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch():
|
||||
transform = AdvancedPromptTransform()
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
storage_key="",
|
||||
|
||||
@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
from models.model import Conversation
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from graphon.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
# from graphon.model_runtime.entities.message_entities import UserPromptMessage
|
||||
# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule
|
||||
# from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||
# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
# from core.prompt.prompt_transform import PromptTransform
|
||||
|
||||
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
"""Primarily used for testing merged cell scenarios"""
|
||||
|
||||
import gc
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import UserDict
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from docx import Document
|
||||
@ -354,15 +356,46 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path):
|
||||
WordExtractor("not-a-file", "tenant", "user")
|
||||
|
||||
|
||||
def test_del_closes_temp_file():
|
||||
def test_close_closes_temp_file():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor._closed = False
|
||||
extractor.temp_file = MagicMock()
|
||||
|
||||
WordExtractor.__del__(extractor)
|
||||
extractor.close()
|
||||
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
|
||||
|
||||
def test_close_is_idempotent():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor._closed = False
|
||||
extractor.temp_file = MagicMock()
|
||||
|
||||
extractor.close()
|
||||
extractor.close()
|
||||
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
|
||||
|
||||
def test_close_handles_async_close_mock():
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor._closed = False
|
||||
extractor.temp_file = MagicMock()
|
||||
extractor.temp_file.close = AsyncMock()
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
extractor.close()
|
||||
gc.collect()
|
||||
|
||||
extractor.temp_file.close.assert_called_once()
|
||||
assert not [
|
||||
warning
|
||||
for warning in caught
|
||||
if issubclass(warning.category, RuntimeWarning) and "AsyncMockMixin._execute_mock_call" in str(warning.message)
|
||||
]
|
||||
|
||||
|
||||
def test_extract_images_handles_invalid_external_cases(monkeypatch):
|
||||
class FakeTargetRef:
|
||||
def __contains__(self, item):
|
||||
|
||||
@ -14,7 +14,7 @@ from core.repositories.human_input_repository import (
|
||||
HumanInputFormSubmissionRepository,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@ -21,7 +21,7 @@ from core.repositories.human_input_repository import (
|
||||
_InvalidTimeoutStatusError,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.human_input_compat import (
|
||||
from core.workflow.human_input_adapter import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
|
||||
@ -6,9 +6,9 @@ from models.workflow import Workflow
|
||||
|
||||
def test_file_to_dict():
|
||||
file = File(
|
||||
id="file1",
|
||||
file_id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
file_type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
storage_key="storage_key",
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import dataclasses
|
||||
from typing import Annotated
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Discriminator, Tag
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
@ -12,17 +13,18 @@ from graphon.runtime import VariablePool
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
ArrayBooleanSegment,
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentUnion,
|
||||
StringSegment,
|
||||
get_segment_discriminator,
|
||||
)
|
||||
@ -47,6 +49,26 @@ from graphon.variables.variables import (
|
||||
StringVariable,
|
||||
Variable,
|
||||
)
|
||||
from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup
|
||||
|
||||
type SegmentUnion = Annotated[
|
||||
(
|
||||
Annotated[NoneSegment, Tag(SegmentType.NONE)]
|
||||
| Annotated[StringSegment, Tag(SegmentType.STRING)]
|
||||
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
|
||||
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
|
||||
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
|
||||
| Annotated[FileSegment, Tag(SegmentType.FILE)]
|
||||
| Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)]
|
||||
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
|
||||
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
|
||||
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
|
||||
def _build_variable_pool(
|
||||
@ -123,7 +145,7 @@ def create_test_file(
|
||||
) -> File:
|
||||
"""Factory function to create File objects for testing"""
|
||||
return File(
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad:
|
||||
assert restored == model
|
||||
|
||||
def test_all_segments_serialization(self):
|
||||
"""Test serialization/deserialization of all segment types"""
|
||||
"""Test file-aware segment serialization through Dify's model boundary."""
|
||||
# Create one instance of each segment type
|
||||
test_file = create_test_file()
|
||||
|
||||
@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad:
|
||||
# Test serialization and deserialization
|
||||
model = _Segments(segments=all_segments)
|
||||
json_str = model.model_dump_json()
|
||||
loaded = _Segments.model_validate_json(json_str)
|
||||
loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
|
||||
|
||||
# Verify all segments are preserved
|
||||
assert len(loaded.segments) == len(all_segments)
|
||||
@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad:
|
||||
assert loaded_segment.value == original.value
|
||||
|
||||
def test_all_variables_serialization(self):
|
||||
"""Test serialization/deserialization of all variable types"""
|
||||
"""Test file-aware variable serialization through Dify's model boundary."""
|
||||
# Create one instance of each variable type
|
||||
test_file = create_test_file()
|
||||
|
||||
@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad:
|
||||
# Test serialization and deserialization
|
||||
model = _Variables(variables=all_variables)
|
||||
json_str = model.model_dump_json()
|
||||
loaded = _Variables.model_validate_json(json_str)
|
||||
loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str)))
|
||||
|
||||
# Verify all variables are preserved
|
||||
assert len(loaded.variables) == len(all_variables)
|
||||
|
||||
@ -35,7 +35,7 @@ def create_test_file(
|
||||
"""Factory function to create File objects for testing."""
|
||||
return File(
|
||||
tenant_id="test-tenant",
|
||||
type=file_type,
|
||||
file_type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
"""
|
||||
Mock node factory for testing workflows with third-party service dependencies.
|
||||
"""Mock node factory for third-party-service workflow tests.
|
||||
|
||||
This module provides a MockNodeFactory that automatically detects and mocks nodes
|
||||
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
||||
The factory follows the same config adaptation path as production
|
||||
`DifyNodeFactory.create_node()`, but swaps selected node classes for mock
|
||||
implementations before instantiation.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
@ -82,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
:param node_config: Node configuration dictionary
|
||||
:return: Node instance (real or mocked)
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_type = node_data.type
|
||||
|
||||
# Check if this node type should be mocked
|
||||
if node_type in self._mock_node_types:
|
||||
node_id = typed_node_config["id"]
|
||||
|
||||
# Create mock node instance
|
||||
mock_class = self._mock_node_types[node_type]
|
||||
resolved_node_data = self._validate_resolved_node_data(mock_class, node_data)
|
||||
if node_type == BuiltinNodeTypes.CODE:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -104,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
)
|
||||
elif node_type == BuiltinNodeTypes.HTTP_REQUEST:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -120,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
|
||||
}:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -130,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
)
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=typed_node_config,
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
@ -140,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
return mock_instance
|
||||
|
||||
# For non-mocked node types, use parent implementation
|
||||
return super().create_node(typed_node_config)
|
||||
return super().create_node(node_config)
|
||||
|
||||
def should_mock_node(self, node_type: NodeType) -> bool:
|
||||
"""
|
||||
|
||||
@ -55,13 +55,14 @@ class MockNodeMixin:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
config: Any,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
mock_config: Optional["MockConfig"] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> None:
|
||||
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
|
||||
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
|
||||
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
|
||||
@ -96,7 +97,7 @@ class MockNodeMixin:
|
||||
kwargs.setdefault("message_transformer", MagicMock())
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
node_id=node_id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user