mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
Merge remote-tracking branch 'myori/main' into feat/collaboration2
This commit is contained in:
commit
59e752dcd3
@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_rag_pipeline
|
||||
# def post(self, pipeline: Pipeline, node_id: str):
|
||||
# """
|
||||
# Run rag pipeline datasource
|
||||
# """
|
||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||
# if not current_user.has_edit_permission:
|
||||
# raise Forbidden()
|
||||
#
|
||||
# if not isinstance(current_user, Account):
|
||||
# raise Forbidden()
|
||||
#
|
||||
# parser = (reqparse.RequestParser()
|
||||
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
#
|
||||
# job_id = args.get("job_id")
|
||||
# if job_id == None:
|
||||
# raise ValueError("missing job_id")
|
||||
# datasource_type = args.get("datasource_type")
|
||||
# if datasource_type == None:
|
||||
# raise ValueError("missing datasource_type")
|
||||
#
|
||||
# rag_pipeline_service = RagPipelineService()
|
||||
# result = rag_pipeline_service.run_datasource_workflow_node_status(
|
||||
# pipeline=pipeline,
|
||||
# node_id=node_id,
|
||||
# job_id=job_id,
|
||||
# account=current_user,
|
||||
# datasource_type=datasource_type,
|
||||
# is_published=True
|
||||
# )
|
||||
#
|
||||
# return result
|
||||
|
||||
|
||||
# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_rag_pipeline
|
||||
# def post(self, pipeline: Pipeline, node_id: str):
|
||||
# """
|
||||
# Run rag pipeline datasource
|
||||
# """
|
||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||
# if not current_user.has_edit_permission:
|
||||
# raise Forbidden()
|
||||
#
|
||||
# if not isinstance(current_user, Account):
|
||||
# raise Forbidden()
|
||||
#
|
||||
# parser = (reqparse.RequestParser()
|
||||
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
#
|
||||
# job_id = args.get("job_id")
|
||||
# if job_id == None:
|
||||
# raise ValueError("missing job_id")
|
||||
# datasource_type = args.get("datasource_type")
|
||||
# if datasource_type == None:
|
||||
# raise ValueError("missing datasource_type")
|
||||
#
|
||||
# rag_pipeline_service = RagPipelineService()
|
||||
# result = rag_pipeline_service.run_datasource_workflow_node_status(
|
||||
# pipeline=pipeline,
|
||||
# node_id=node_id,
|
||||
# job_id=job_id,
|
||||
# account=current_user,
|
||||
# datasource_type=datasource_type,
|
||||
# is_published=False
|
||||
# )
|
||||
#
|
||||
# return result
|
||||
#
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
|
||||
@ -7,7 +7,8 @@ import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_user_id=current_user.id,
|
||||
)
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_end_user_id=None,
|
||||
# submission_end_user_id=_end_user.id,
|
||||
)
|
||||
|
||||
@ -2,7 +2,6 @@ import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.enums import WorkflowType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
@ -22,7 +21,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
@ -265,22 +264,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
# graph_config["nodes"] = real_run_nodes
|
||||
# graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
# Create explicit graph init context for Graph.init.
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if start_node_id is None:
|
||||
|
||||
@ -3,7 +3,6 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDictAdapter
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.graph import Graph
|
||||
@ -67,7 +66,12 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.node_factory import (
|
||||
DifyGraphInitContext,
|
||||
DifyNodeFactory,
|
||||
get_default_root_node_id,
|
||||
resolve_workflow_node_class,
|
||||
)
|
||||
from core.workflow.system_variables import (
|
||||
build_bootstrap_variables,
|
||||
default_system_variables,
|
||||
@ -127,24 +131,25 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
# Create explicit graph init context for Graph.init.
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
@ -289,22 +294,23 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
# Create explicit graph init context for Graph.init.
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=graph_config,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=self._app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
|
||||
@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
|
||||
return ProtectedResourceMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
except (RequestError, ValidationError, json.JSONDecodeError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
elif response.status_code == 404:
|
||||
continue # Try next URL
|
||||
except (RequestError, ValidationError):
|
||||
except (RequestError, ValidationError, json.JSONDecodeError):
|
||||
continue # Try next URL
|
||||
|
||||
return None
|
||||
@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
except RequestError:
|
||||
except (RequestError, json.JSONDecodeError, IndexError):
|
||||
# Not support resource discovery, fall back to well-known OAuth metadata
|
||||
return False, ""
|
||||
|
||||
|
||||
@ -61,27 +61,28 @@ class TokenBufferMemory:
|
||||
:param is_user_message: whether this is a user message
|
||||
:return: PromptMessage
|
||||
"""
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
match self.conversation.mode:
|
||||
case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
|
||||
if not message.workflow_run_id:
|
||||
raise ValueError("Workflow run ID not found")
|
||||
if not message.workflow_run_id:
|
||||
raise ValueError("Workflow run ID not found")
|
||||
|
||||
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
else:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
|
||||
)
|
||||
if not workflow_run:
|
||||
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
|
||||
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
case _:
|
||||
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
if file_extra_config and app_record:
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@ -19,6 +20,16 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
|
||||
hosts: list[str]
|
||||
verify_certs: bool
|
||||
ssl_show_warn: bool
|
||||
request_timeout: int
|
||||
retry_on_timeout: bool
|
||||
max_retries: int
|
||||
basic_auth: tuple[str, str]
|
||||
|
||||
|
||||
def create_ssl_context() -> ssl.SSLContext:
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
|
||||
raise ValueError("config HOSTS is required")
|
||||
return values
|
||||
|
||||
def to_elasticsearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": self.hosts.split(","),
|
||||
"verify_certs": False,
|
||||
"ssl_show_warn": False,
|
||||
"request_timeout": 30000,
|
||||
"retry_on_timeout": True,
|
||||
"max_retries": 10,
|
||||
}
|
||||
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
|
||||
params = HuaweiElasticsearchParamsDict(
|
||||
hosts=self.hosts.split(","),
|
||||
verify_certs=False,
|
||||
ssl_show_warn=False,
|
||||
request_timeout=30000,
|
||||
retry_on_timeout=True,
|
||||
max_retries=10,
|
||||
)
|
||||
if self.username and self.password:
|
||||
params["basic_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
|
||||
UGC_INDEX_PREFIX = "ugc_index"
|
||||
|
||||
|
||||
class LindormOpenSearchParamsDict(TypedDict, total=False):
|
||||
hosts: str | None
|
||||
use_ssl: bool
|
||||
pool_maxsize: int
|
||||
timeout: int
|
||||
http_auth: tuple[str, str]
|
||||
|
||||
|
||||
class LindormVectorStoreConfig(BaseModel):
|
||||
hosts: str | None
|
||||
username: str | None = None
|
||||
@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
|
||||
raise ValueError("config PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params: dict[str, Any] = {
|
||||
"hosts": self.hosts,
|
||||
"use_ssl": False,
|
||||
"pool_maxsize": 128,
|
||||
"timeout": 30,
|
||||
}
|
||||
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
|
||||
params = LindormOpenSearchParamsDict(
|
||||
hosts=self.hosts,
|
||||
use_ssl=False,
|
||||
pool_maxsize=128,
|
||||
timeout=30,
|
||||
)
|
||||
if self.username and self.password:
|
||||
params["http_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from configs.middleware.vdb.opensearch_config import AuthMethod
|
||||
@ -21,6 +22,20 @@ from models.dataset import Dataset
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _OpenSearchHostDict(TypedDict):
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
class OpenSearchParamsDict(TypedDict, total=False):
|
||||
hosts: list[_OpenSearchHostDict]
|
||||
use_ssl: bool
|
||||
verify_certs: bool
|
||||
connection_class: type
|
||||
pool_maxsize: int
|
||||
http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
|
||||
service=self.aws_service, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": [{"host": self.host, "port": self.port}],
|
||||
"use_ssl": self.secure,
|
||||
"verify_certs": self.verify_certs,
|
||||
"connection_class": Urllib3HttpConnection,
|
||||
"pool_maxsize": 20,
|
||||
}
|
||||
def to_opensearch_params(self) -> OpenSearchParamsDict:
|
||||
params = OpenSearchParamsDict(
|
||||
hosts=[{"host": self.host, "port": self.port}],
|
||||
use_ssl=self.secure,
|
||||
verify_certs=self.verify_certs,
|
||||
connection_class=Urllib3HttpConnection,
|
||||
pool_maxsize=20,
|
||||
)
|
||||
|
||||
if self.auth_method == "basic":
|
||||
logger.info("Using basic authentication for OpenSearch Vector DB")
|
||||
|
||||
@ -19,5 +19,18 @@ def remove_leading_symbols(text: str) -> str:
|
||||
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
|
||||
pattern = re.compile(
|
||||
r"""
|
||||
^
|
||||
(?:
|
||||
[\u2000-\u2025] # General Punctuation: spaces, quotes, dashes
|
||||
| [\u2027-\u206F] # General Punctuation: ellipsis, underscores, etc.
|
||||
| [\u2E00-\u2E7F] # Supplemental Punctuation: medieval, ancient marks
|
||||
| [\u3000-\u300F] # CJK Punctuation: 、。〃「」『》』 (excludes 【】)
|
||||
| [\u3012-\u303F] # CJK Punctuation: 〖〗〔〕〘〙〚〛〜 etc.
|
||||
| ["#$%&'()*+,./:;<=>?@^_`~] # ASCII punctuation (excludes []【】)
|
||||
)+
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
return re.sub(pattern, "", text)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import importlib
|
||||
import pkgutil
|
||||
from collections.abc import Callable, Iterator, Mapping, MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, cast, final, override
|
||||
|
||||
@ -67,6 +68,31 @@ _START_NODE_TYPES: frozenset[NodeType] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DifyGraphInitContext:
|
||||
"""Explicit graph-init values owned by the workflow layer.
|
||||
|
||||
Dify is gradually removing direct `GraphInitParams` construction from its
|
||||
production call sites. Keep the translation here until `graphon` exposes an
|
||||
equivalent explicit API.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
graph_config: Mapping[str, Any]
|
||||
run_context: Mapping[str, Any]
|
||||
call_depth: int
|
||||
|
||||
def to_graph_init_params(self) -> "GraphInitParams":
|
||||
from graphon.entities import GraphInitParams
|
||||
|
||||
return GraphInitParams(
|
||||
workflow_id=self.workflow_id,
|
||||
graph_config=self.graph_config,
|
||||
run_context=self.run_context,
|
||||
call_depth=self.call_depth,
|
||||
)
|
||||
|
||||
|
||||
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
|
||||
package = importlib.import_module(package_name)
|
||||
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
|
||||
@ -237,6 +263,19 @@ class DifyNodeFactory(NodeFactory):
|
||||
Default implementation of NodeFactory that resolves node classes from the live registry.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_graph_init_context(
|
||||
cls,
|
||||
*,
|
||||
graph_init_context: DifyGraphInitContext,
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
) -> "DifyNodeFactory":
|
||||
"""Bridge Dify's explicit init context into the current `graphon` API."""
|
||||
return cls(
|
||||
graph_init_params=graph_init_context.to_graph_init_params(),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
|
||||
@ -29,7 +29,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
def post_init(self) -> None:
|
||||
from core.workflow.node_runtime import DifyFileReferenceFactory
|
||||
|
||||
self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
|
||||
self._file_reference_factory = DifyFileReferenceFactory(self.run_context)
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
||||
@ -24,7 +24,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class
|
||||
from core.workflow.node_factory import (
|
||||
DifyGraphInitContext,
|
||||
DifyNodeFactory,
|
||||
is_start_node_type,
|
||||
resolve_workflow_node_class,
|
||||
)
|
||||
from core.workflow.system_variables import (
|
||||
default_system_variables,
|
||||
get_node_creation_preload_selectors,
|
||||
@ -251,17 +256,18 @@ class WorkflowEntry:
|
||||
node_version = str(node_config_data.version)
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
# init graph context and runtime state
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
@ -313,8 +319,8 @@ class WorkflowEntry:
|
||||
)
|
||||
|
||||
# init workflow run state
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
@ -409,17 +415,18 @@ class WorkflowEntry:
|
||||
variable_pool = VariablePool()
|
||||
add_variables_to_pool(variable_pool, default_system_variables())
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
# init graph context and runtime state
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=tenant_id,
|
||||
app_id="",
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id="",
|
||||
graph_config=graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=tenant_id,
|
||||
app_id="",
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
@ -430,8 +437,8 @@ class WorkflowEntry:
|
||||
|
||||
# init workflow run state
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
|
||||
@ -5,12 +5,30 @@ from typing import Any
|
||||
import pytz # type: ignore[import-untyped]
|
||||
from celery import Celery, Task
|
||||
from celery.schedules import crontab
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
def get_celery_ssl_options() -> dict[str, Any] | None:
|
||||
class _CelerySentinelKwargsDict(TypedDict):
|
||||
socket_timeout: float | None
|
||||
password: str | None
|
||||
|
||||
|
||||
class CelerySentinelTransportDict(TypedDict):
|
||||
master_name: str | None
|
||||
sentinel_kwargs: _CelerySentinelKwargsDict
|
||||
|
||||
|
||||
class CelerySSLOptionsDict(TypedDict):
|
||||
ssl_cert_reqs: int
|
||||
ssl_ca_certs: str | None
|
||||
ssl_certfile: str | None
|
||||
ssl_keyfile: str | None
|
||||
|
||||
|
||||
def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
|
||||
"""Get SSL configuration for Celery broker/backend connections."""
|
||||
# Only apply SSL if we're using Redis as broker/backend
|
||||
if not dify_config.BROKER_USE_SSL:
|
||||
@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None:
|
||||
|
||||
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
|
||||
|
||||
ssl_options = {
|
||||
"ssl_cert_reqs": ssl_cert_reqs,
|
||||
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
|
||||
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
|
||||
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
|
||||
}
|
||||
|
||||
return ssl_options
|
||||
return CelerySSLOptionsDict(
|
||||
ssl_cert_reqs=ssl_cert_reqs,
|
||||
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS,
|
||||
ssl_certfile=dify_config.REDIS_SSL_CERTFILE,
|
||||
ssl_keyfile=dify_config.REDIS_SSL_KEYFILE,
|
||||
)
|
||||
|
||||
|
||||
def get_celery_broker_transport_options() -> dict[str, Any]:
|
||||
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
|
||||
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
|
||||
if dify_config.CELERY_USE_SENTINEL:
|
||||
return {
|
||||
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
"sentinel_kwargs": {
|
||||
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
"password": dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
},
|
||||
}
|
||||
return CelerySentinelTransportDict(
|
||||
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
|
||||
sentinel_kwargs=_CelerySentinelKwargsDict(
|
||||
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
|
||||
password=dify_config.CELERY_SENTINEL_PASSWORD,
|
||||
),
|
||||
)
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@ -674,28 +674,24 @@ class AppModelConfig(TypeBase):
|
||||
def suggested_questions_list(self) -> list[str]:
|
||||
return json.loads(self.suggested_questions) if self.suggested_questions else []
|
||||
|
||||
def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})
|
||||
|
||||
@property
|
||||
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig,
|
||||
json.loads(self.suggested_questions_after_answer)
|
||||
if self.suggested_questions_after_answer
|
||||
else {"enabled": False},
|
||||
)
|
||||
return self._get_enabled_config(self.suggested_questions_after_answer)
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
|
||||
return self._get_enabled_config(self.speech_to_text)
|
||||
|
||||
@property
|
||||
def text_to_speech_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
|
||||
return self._get_enabled_config(self.text_to_speech)
|
||||
|
||||
@property
|
||||
def retriever_resource_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
)
|
||||
return self._get_enabled_config(self.retriever_resource, default_enabled=True)
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
@ -722,7 +718,7 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
|
||||
return self._get_enabled_config(self.more_like_this)
|
||||
|
||||
@property
|
||||
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
|
||||
@ -902,7 +898,7 @@ class InstalledApp(TypeBase):
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
class TrialApp(TypeBase):
|
||||
__tablename__ = "trial_apps"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
|
||||
@ -911,18 +907,26 @@ class TrialApp(Base):
|
||||
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, default=gen_uuidv4_string)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
|
||||
)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
class AccountTrialAppRecord(TypeBase):
|
||||
__tablename__ = "account_trial_app_records"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
|
||||
@ -930,11 +934,19 @@ class AccountTrialAppRecord(Base):
|
||||
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
|
||||
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
|
||||
)
|
||||
id = mapped_column(StringUUID, default=gen_uuidv4_string)
|
||||
account_id = mapped_column(StringUUID, nullable=False)
|
||||
app_id = mapped_column(StringUUID, nullable=False)
|
||||
count = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
|
||||
)
|
||||
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
insert_default=func.current_timestamp(),
|
||||
server_default=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
|
||||
@ -66,12 +66,15 @@ def build_file_from_stored_mapping(
|
||||
record_id = resolve_file_record_id(mapping)
|
||||
transfer_method = FileTransferMethod.value_of(mapping["transfer_method"])
|
||||
|
||||
if transfer_method == FileTransferMethod.TOOL_FILE and record_id:
|
||||
mapping["tool_file_id"] = record_id
|
||||
elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id:
|
||||
mapping["upload_file_id"] = record_id
|
||||
elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id:
|
||||
mapping["datasource_file_id"] = record_id
|
||||
match transfer_method:
|
||||
case FileTransferMethod.TOOL_FILE if record_id:
|
||||
mapping["tool_file_id"] = record_id
|
||||
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id:
|
||||
mapping["upload_file_id"] = record_id
|
||||
case FileTransferMethod.DATASOURCE_FILE if record_id:
|
||||
mapping["datasource_file_id"] = record_id
|
||||
case _:
|
||||
pass
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
|
||||
remote_url = mapping.get("remote_url")
|
||||
|
||||
@ -173,7 +173,7 @@ dev = [
|
||||
"sseclient-py>=1.8.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pyrefly>=0.59.1",
|
||||
"pyrefly>=0.60.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
|
||||
@ -467,61 +467,67 @@ class AppDslService:
|
||||
)
|
||||
|
||||
# Initialize app based on mode
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow_data = data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise ValueError("Missing workflow data for workflow/advanced chat app")
|
||||
match app_mode:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
|
||||
workflow_data = data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise ValueError("Missing workflow data for workflow/advanced chat app")
|
||||
|
||||
environment_variables_list = workflow_data.get("environment_variables", [])
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = workflow_data.get("conversation_variables", [])
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
environment_variables_list = workflow_data.get("environment_variables", [])
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = workflow_data.get("conversation_variables", [])
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj)
|
||||
for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
if current_draft_workflow:
|
||||
unique_hash = current_draft_workflow.unique_hash
|
||||
else:
|
||||
unique_hash = None
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
decrypted_id
|
||||
for dataset_id in dataset_ids
|
||||
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
|
||||
]
|
||||
workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
|
||||
# Initialize model config
|
||||
model_config = data.get("model_config")
|
||||
if not model_config or not isinstance(model_config, dict):
|
||||
raise ValueError("Missing model_config for chat/agent-chat/completion app")
|
||||
# Initialize or update model config
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
workflow_service = WorkflowService()
|
||||
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
if current_draft_workflow:
|
||||
unique_hash = current_draft_workflow.unique_hash
|
||||
else:
|
||||
unique_hash = None
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
decrypted_id
|
||||
for dataset_id in dataset_ids
|
||||
if (
|
||||
decrypted_id := self.decrypt_dataset_id(
|
||||
encrypted_data=dataset_id, tenant_id=app.tenant_id
|
||||
)
|
||||
)
|
||||
]
|
||||
workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION:
|
||||
# Initialize model config
|
||||
model_config = data.get("model_config")
|
||||
if not model_config or not isinstance(model_config, dict):
|
||||
raise ValueError("Missing model_config for chat/agent-chat/completion app")
|
||||
# Initialize or update model config
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
self._session.add(app_model_config)
|
||||
app_model_config_was_updated.send(app, app_model_config=app_model_config)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
self._session.add(app_model_config)
|
||||
app_model_config_was_updated.send(app, app_model_config=app_model_config)
|
||||
case _:
|
||||
raise ValueError("Invalid app mode")
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -132,8 +132,8 @@ class FileService:
|
||||
return file_size <= file_size_limit
|
||||
|
||||
def get_file_base64(self, file_id: str) -> str:
|
||||
upload_file = (
|
||||
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = self._session_maker(expire_on_commit=False).scalar(
|
||||
select(UploadFile).where(UploadFile.id == file_id).limit(1)
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
@ -178,7 +178,7 @@ class FileService:
|
||||
Return a short text preview extracted from a document file.
|
||||
"""
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
@ -200,7 +200,7 @@ class FileService:
|
||||
if not result:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -220,7 +220,7 @@ class FileService:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -231,7 +231,7 @@ class FileService:
|
||||
|
||||
def get_public_image_preview(self, file_id: str):
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -247,7 +247,7 @@ class FileService:
|
||||
|
||||
def get_file_content(self, file_id: str) -> str:
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from core.db.session_factory import session_factory
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
|
||||
class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with session_factory.create_session() as session:
|
||||
return session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
@ -24,7 +23,7 @@ class PluginAutoUpgradeService:
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with session_factory.create_session() as session:
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
@ -51,7 +50,7 @@ class PluginAutoUpgradeService:
|
||||
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with session_factory.create_session() as session:
|
||||
exist_strategy = session.scalar(
|
||||
select(TenantPluginAutoUpgradeStrategy)
|
||||
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants.languages import languages
|
||||
@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
|
||||
class RecommendedAppItemDict(TypedDict):
|
||||
id: str
|
||||
app: App | None
|
||||
app_id: str
|
||||
description: Any
|
||||
copyright: Any
|
||||
privacy_policy: Any
|
||||
custom_disclaimer: str
|
||||
category: str
|
||||
position: int
|
||||
is_listed: bool
|
||||
|
||||
|
||||
class RecommendedAppsResultDict(TypedDict):
|
||||
recommended_apps: list[RecommendedAppItemDict]
|
||||
categories: list[str]
|
||||
|
||||
|
||||
class RecommendedAppDetailDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
icon: Any
|
||||
icon_background: str | None
|
||||
mode: str
|
||||
export_data: str
|
||||
|
||||
|
||||
class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
Retrieval recommended app from database
|
||||
"""
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict:
|
||||
result = self.fetch_recommended_apps_from_db(language)
|
||||
return result
|
||||
|
||||
def get_recommend_app_detail(self, app_id: str):
|
||||
def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None:
|
||||
result = self.fetch_recommended_app_detail_from_db(app_id)
|
||||
return result
|
||||
|
||||
@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return RecommendAppType.DATABASE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_db(cls, language: str):
|
||||
def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict:
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
:param language: language
|
||||
@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
).all()
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
recommended_apps_result: list[RecommendedAppItemDict] = []
|
||||
for recommended_app in recommended_apps:
|
||||
app = recommended_app.app
|
||||
if not app or not app.is_public:
|
||||
@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
if not site:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
recommended_app_result: RecommendedAppItemDict = {
|
||||
"id": recommended_app.id,
|
||||
"app": recommended_app.app,
|
||||
"app_id": recommended_app.app_id,
|
||||
@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
|
||||
categories.add(recommended_app.category)
|
||||
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories))
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
|
||||
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None:
|
||||
"""
|
||||
Fetch recommended app detail from db.
|
||||
:param app_id: App ID
|
||||
@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
if not app_model or not app_model.is_public:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": app_model.id,
|
||||
"name": app_model.name,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"mode": app_model.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=app_model),
|
||||
}
|
||||
return RecommendedAppDetailDict(
|
||||
id=app_model.id,
|
||||
name=app_model.name,
|
||||
icon=app_model.icon,
|
||||
icon_background=app_model.icon_background,
|
||||
mode=app_model.mode,
|
||||
export_data=AppDslService.export_dsl(app_model=app_model),
|
||||
)
|
||||
|
||||
@ -104,32 +104,32 @@ class WebhookService:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first()
|
||||
webhook_trigger = session.scalar(
|
||||
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1)
|
||||
)
|
||||
if not webhook_trigger:
|
||||
raise ValueError(f"Webhook not found: {webhook_id}")
|
||||
|
||||
if is_debug:
|
||||
workflow = (
|
||||
session.query(Workflow)
|
||||
.filter(
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
# Check if the corresponding AppTrigger exists
|
||||
app_trigger = (
|
||||
session.query(AppTrigger)
|
||||
.filter(
|
||||
app_trigger = session.scalar(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.app_id == webhook_trigger.app_id,
|
||||
AppTrigger.node_id == webhook_trigger.node_id,
|
||||
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not app_trigger:
|
||||
@ -146,14 +146,14 @@ class WebhookService:
|
||||
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
|
||||
|
||||
# Get workflow
|
||||
workflow = (
|
||||
session.query(Workflow)
|
||||
.filter(
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.app_id == webhook_trigger.app_id,
|
||||
Workflow.version != Workflow.VERSION_DRAFT,
|
||||
)
|
||||
.order_by(Workflow.created_at.desc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
|
||||
|
||||
@ -5,7 +5,7 @@ import uuid
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.entities import GraphInitParams, WorkflowNodeExecution
|
||||
from graphon.entities import WorkflowNodeExecution
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import (
|
||||
@ -48,7 +48,12 @@ from core.workflow.human_input_compat import (
|
||||
normalize_human_input_node_data_for_graph,
|
||||
parse_human_input_delivery_methods,
|
||||
)
|
||||
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type
|
||||
from core.workflow.node_factory import (
|
||||
LATEST_VERSION,
|
||||
DifyGraphInitContext,
|
||||
get_node_type_classes_mapping,
|
||||
is_start_node_type,
|
||||
)
|
||||
from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
@ -1204,18 +1209,20 @@ class WorkflowService:
|
||||
node_config: NodeConfigDict,
|
||||
variable_pool: VariablePool,
|
||||
) -> HumanInputNode:
|
||||
graph_init_params = GraphInitParams(
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=account.id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
run_context=build_dify_run_context(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
user_id=account.id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_init_params = graph_init_context.to_graph_init_params()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
@ -1225,7 +1232,7 @@ class WorkflowService:
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
|
||||
runtime=DifyHumanInputNodeRuntime(run_context),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import time

import click
from celery import shared_task # type: ignore
from sqlalchemy import select, update

from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):

with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))

if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()

if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()

for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
for segment in segments:
@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
elif action == "update":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()

@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
multimodal_documents = []
@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
else:

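The hunks above migrate this task from legacy `session.query(...)` chains to SQLAlchemy 2.0-style statements. A self-contained sketch of the two patterns being adopted, using a toy model and an in-memory SQLite engine rather than the Dify schema:

from sqlalchemy import String, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)
    indexing_status: Mapped[str] = mapped_column(String(32), default="completed")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(id=1), Doc(id=2)])
    session.commit()

    # Read side: session.scalars(select(...)) replaces session.query(...).all()
    docs = session.scalars(select(Doc).where(Doc.indexing_status == "completed")).all()

    # Write side: session.execute(update(...).values(...)) replaces
    # session.query(...).update({...}, synchronize_session=False)
    session.execute(
        update(Doc).where(Doc.id.in_([d.id for d in docs])).values(indexing_status="indexing")
    )
    session.commit()
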
@ -862,6 +862,15 @@ class TestAuthOrchestration:
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None

# JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None

@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_authorization_server_metadata(self, mock_get):
# Success
@ -892,6 +901,14 @@ class TestAuthOrchestration:
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None

# JSONDecodeError (non-JSON 200 response)
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None

def test_get_effective_scope(self):
prm = ProtectedResourceMetadata(
resource="https://api.example.com",
@ -997,6 +1014,24 @@ class TestAuthOrchestration:
supported, url = check_support_resource_discovery("https://api")
assert supported is False

# Case 6: JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_res = Mock()
bad_json_res.status_code = 200
bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False
assert url == ""

# Case 7: Empty authorization_servers array (IndexError)
empty_res = Mock()
empty_res.status_code = 200
empty_res.json.return_value = {"authorization_servers": []}
mock_get.return_value = empty_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False

def test_discover_oauth_metadata(self):
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:

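The added test branches above cover an HTTP 200 whose body is not valid JSON. A standalone illustration of the mocking technique (standard library only; the real discovery helpers are not reproduced here):

import json
from unittest.mock import Mock

# Fake 200 response whose .json() raises, e.g. a proxy returning an HTML error page.
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)

def fetch_metadata(response):
    # Stand-in for the discovery helpers: treat undecodable bodies as "no metadata".
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except json.JSONDecodeError:
        return None

assert fetch_metadata(bad_json_response) is None
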
@ -110,6 +110,34 @@ class TestFetchMemory:
)


class TestDifyGraphInitContext:
def test_to_graph_init_params_preserves_explicit_values(self):
run_context = {
DIFY_RUN_CONTEXT_KEY: DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
"extra": "value",
}
graph_config = {"nodes": [], "edges": []}
graph_init_context = node_factory.DifyGraphInitContext(
workflow_id="workflow-id",
graph_config=graph_config,
run_context=run_context,
call_depth=2,
)

result = graph_init_context.to_graph_init_params()

assert result.workflow_id == "workflow-id"
assert result.graph_config == graph_config
assert result.run_context == run_context
assert result.call_depth == 2


class TestDefaultWorkflowCodeExecutor:
def test_execute_delegates_to_code_executor(self, monkeypatch):
executor = node_factory.DefaultWorkflowCodeExecutor()
@ -172,6 +200,23 @@ class TestCodeExecutorJinja2TemplateRenderer:


class TestDifyNodeFactoryInit:
def test_from_graph_init_context_translates_before_init(self):
graph_init_context = MagicMock()
graph_init_context.to_graph_init_params.return_value = sentinel.graph_init_params

with patch.object(node_factory.DifyNodeFactory, "__init__", return_value=None) as init:
factory = node_factory.DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state,
)

assert isinstance(factory, node_factory.DifyNodeFactory)
graph_init_context.to_graph_init_params.assert_called_once_with()
init.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
)

def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
graph_runtime_state = sentinel.graph_runtime_state

@ -349,7 +349,7 @@ class TestWorkflowEntrySingleStepRun:
]

with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(
workflow_entry,
"GraphRuntimeState",
@ -358,7 +358,7 @@ class TestWorkflowEntrySingleStepRun:
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeLLMNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
patch.object(
@ -412,12 +412,12 @@ class TestWorkflowEntrySingleStepRun:
raise NotImplementedError

with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
@ -481,12 +481,12 @@ class TestWorkflowEntrySingleStepRun:
return {"question": ["node", "question"]}

with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeDatasourceNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
@ -541,12 +541,12 @@ class TestWorkflowEntrySingleStepRun:
return "1"

with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool"),
patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
@ -651,14 +651,18 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool,
patch.object(
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
) as graph_init_params,
workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context
) as graph_init_context_cls,
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
) as build_dify_run_context,
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
) as dify_node_factory_cls,
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",
@ -688,7 +692,7 @@ class TestWorkflowEntryHelpers:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_params.assert_called_once_with(
graph_init_context_cls.assert_called_once_with(
workflow_id="",
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
"node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}
@ -697,7 +701,7 @@ class TestWorkflowEntryHelpers:
call_depth=0,
)
dify_node_factory_cls.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_init_context=sentinel.graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state,
)
mapping_user_inputs_to_variable_pool.assert_called_once_with(
@ -734,11 +738,15 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables),
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
patch.object(workflow_entry, "add_variables_to_pool"),
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
),
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",

@ -6,23 +6,23 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"


def _patched_session():
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
"""Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=False)
mock_factory = MagicMock()
mock_factory.create_session.return_value = session
patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, session


class TestGetStrategy:
def test_returns_strategy_when_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
strategy = MagicMock()
session.scalar.return_value = strategy

with p1, p2:
with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

result = PluginAutoUpgradeService.get_strategy("t1")
@ -30,10 +30,10 @@ class TestGetStrategy:
assert result is strategy

def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None

with p1, p2:
with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

result = PluginAutoUpgradeService.get_strategy("t1")
@ -43,10 +43,10 @@ class TestGetStrategy:

class TestChangeStrategy:
def test_creates_new_strategy(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None

with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.return_value = MagicMock()
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

@ -63,11 +63,11 @@ class TestChangeStrategy:
session.add.assert_called_once()

def test_updates_existing_strategy(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
session.scalar.return_value = existing

with p1, p2:
with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

result = PluginAutoUpgradeService.change_strategy(
@ -89,12 +89,11 @@ class TestChangeStrategy:

class TestExcludePlugin:
def test_creates_default_strategy_when_none_exists(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None

with (
p1,
p2,
patch(f"{MODULE}.select"),
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
@ -110,13 +109,13 @@ class TestExcludePlugin:
cs.assert_called_once()

def test_appends_to_exclude_list_in_exclude_mode(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p-existing"]
session.scalar.return_value = existing

with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -128,13 +127,13 @@ class TestExcludePlugin:
assert existing.exclude_plugins == ["p-existing", "p-new"]

def test_removes_from_include_list_in_partial_mode(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "partial"
existing.include_plugins = ["p1", "p2"]
session.scalar.return_value = existing

with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -146,12 +145,12 @@ class TestExcludePlugin:
assert existing.include_plugins == ["p2"]

def test_switches_to_exclude_mode_from_all(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "all"
session.scalar.return_value = existing

with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -164,13 +163,13 @@ class TestExcludePlugin:
assert existing.exclude_plugins == ["p1"]

def test_no_duplicate_in_exclude_list(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p1"]
session.scalar.return_value = existing

with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"

@ -165,7 +165,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "test_key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file

with patch("services.file_service.storage") as mock_storage:
mock_storage.load_once.return_value = b"test content"
@ -178,7 +178,7 @@ class TestFileService:
mock_storage.load_once.assert_called_once_with("test_key")

def test_get_file_base64_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_base64("non_existent")

@ -215,7 +215,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "pdf"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file

with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract:
mock_extract.return_value = "Extracted text content"
@ -227,7 +227,7 @@ class TestFileService:
assert result == "Extracted text content"

def test_get_file_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_preview("non_existent")

@ -235,7 +235,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "exe"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError):
file_service.get_file_preview("file_id")

@ -246,7 +246,7 @@ class TestFileService:
upload_file.extension = "jpg"
upload_file.mime_type = "image/jpeg"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file

with (
patch("services.file_service.file_helpers.verify_image_signature") as mock_verify,
@ -269,7 +269,7 @@ class TestFileService:
file_service.get_image_preview("file_id", "ts", "nonce", "sign")

def test_get_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
@ -279,7 +279,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(UnsupportedFileTypeError):
@ -289,7 +289,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file

with (
patch("services.file_service.file_helpers.verify_file_signature") as mock_verify,
@ -309,7 +309,7 @@ class TestFileService:
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")

def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
@ -321,7 +321,7 @@ class TestFileService:
upload_file.extension = "png"
upload_file.mime_type = "image/png"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file

with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"image content"
@ -330,7 +330,7 @@ class TestFileService:
assert mime == "image/png"

def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_public_image_preview("file_id")

@ -338,7 +338,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError):
file_service.get_public_image_preview("file_id")

@ -346,7 +346,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file

with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"hello world"
@ -354,7 +354,7 @@ class TestFileService:
assert result == "hello world"

def test_get_file_content_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_content("file_id")

@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App:
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_session = MagicMock()
fake_session.query.return_value = _FakeQuery(None)
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)

# Act / Assert
@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_foun
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)

# Act / Assert
@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_lim
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)

# Act / Assert
@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)

# Act / Assert
@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(m
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)

# Act / Assert
@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mod
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}

fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)

# Act
@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(mo
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}

fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)

# Act

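In the rewritten tests above, the successive lookups (webhook trigger, app trigger, workflow) are driven by a single mock: assigning a list to `side_effect` makes each call to `session.scalar` return the next element. A tiny standalone illustration:

from unittest.mock import MagicMock

fake_session = MagicMock()
fake_session.scalar.side_effect = ["webhook-trigger", "app-trigger", "workflow"]

assert fake_session.scalar("first statement") == "webhook-trigger"
assert fake_session.scalar("second statement") == "app-trigger"
assert fake_session.scalar("third statement") == "workflow"
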
@ -2826,9 +2826,9 @@ class TestWorkflowServiceFreeNodeExecution:
variable_pool = MagicMock()

with (
patch("services.workflow_service.GraphInitParams") as mock_graph_init_params,
patch("services.workflow_service.DifyGraphInitContext") as mock_graph_init_context_cls,
patch("services.workflow_service.GraphRuntimeState"),
patch("services.workflow_service.build_dify_run_context"),
patch("services.workflow_service.build_dify_run_context") as mock_build_dify_run_context,
patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
patch("services.workflow_service.HumanInputNode") as mock_node_cls,
):
@ -2837,4 +2837,17 @@ class TestWorkflowServiceFreeNodeExecution:
)
assert node == mock_node_cls.return_value
mock_node_cls.assert_called_once()
mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context)
mock_graph_init_context_cls.assert_called_once_with(
workflow_id="wf-1",
graph_config=workflow.graph_dict,
run_context=mock_build_dify_run_context.return_value,
call_depth=0,
)
mock_runtime_cls.assert_called_once_with(mock_build_dify_run_context.return_value)
mock_node_cls.assert_called_once_with(
id="n-1",
config=node_config,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
runtime=mock_runtime_cls.return_value,
)

@ -19,7 +19,57 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"),
("[Example](http://example.com) some text", "[Example](http://example.com) some text"),
# Leading symbols before markdown link are removed, including the opening bracket [
("@[Test](https://example.com)", "Test](https://example.com)"),
("@[Test](https://example.com)", "[Test](https://example.com)"),
("~~标题~~", "标题~~"),
('""quoted', "quoted"),
("''test", "test"),
("##话题", "话题"),
("$$价格", "价格"),
("%%百分比", "百分比"),
("&&与逻辑", "与逻辑"),
("((括号))", "括号))"),
("**强调**", "强调**"),
("++自增", "自增"),
(",,逗号", "逗号"),
("..省略", "省略"),
("//注释", "注释"),
("::范围", "范围"),
(";;分号", "分号"),
("<<左移", "左移"),
("==等于", "等于"),
(">>右移", "右移"),
("??疑问", "疑问"),
("@@提及", "提及"),
("^^上标", "上标"),
("__下划线", "下划线"),
("``代码", "代码"),
("~~删除线", "删除线"),
(" 全角空格开头", "全角空格开头"),
("、顿号开头", "顿号开头"),
("。句号开头", "句号开头"),
("「引号」测试", "引号」测试"),
("『书名号』", "书名号』"),
("【保留】测试", "【保留】测试"),
("〖括号〗测试", "括号〗测试"),
("〔括号〕测试", "括号〕测试"),
("~~【保留】~~", "【保留】~~"),
('"[公告]"', '[公告]"'),
("[公告] 更新", "[公告] 更新"),
("【通知】重要", "【通知】重要"),
("[[嵌套]]", "[[嵌套]]"),
("【【嵌套】】", "【【嵌套】】"),
("[【混合】]", "[【混合】]"),
("normal text", "normal text"),
("123数字", "123数字"),
("中文开头", "中文开头"),
("alpha", "alpha"),
("~", ""),
("【", "【"),
("[", "["),
("~~~", ""),
("【【【", "【【【"),
("\t制表符", "\t制表符"),
("\n换行", "\n换行"),
],
)
def test_remove_leading_symbols(input_text, expected_output):

api/uv.lock (generated, 24 changes)

@ -1585,7 +1585,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.20.0" },
{ name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.59.1" },
{ name = "pyrefly", specifier = ">=0.60.0" },
{ name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" },
@ -4850,19 +4850,19 @@ wheels = [

[[package]]
name = "pyrefly"
version = "0.59.1"
version = "0.60.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c6/c7/28d14b64888e2d03815627ebff8d57a9f08389c4bbebfe70ae1ed98a1267/pyrefly-0.60.0.tar.gz", hash = "sha256:2499f5b6ff5342e86dfe1cd94bcce133519bbbc93b7ad5636195fea4f0fa3b81", size = 5500389, upload-time = "2026-04-06T19:57:30.643Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
{ url = "https://files.pythonhosted.org/packages/31/99/6c9984a09220e5eb7dd5c869b7a32d25c3d06b5e8854c6eb679db1145c3e/pyrefly-0.60.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bf1691af0fee69d0c99c3c6e9d26ab6acd3c8afef96416f9ba2e74934833b7b5", size = 12921262, upload-time = "2026-04-06T19:57:00.745Z" },
{ url = "https://files.pythonhosted.org/packages/05/b3/6216aa3c00c88e59a27eb4149851b5affe86eeea6129f4224034a32dddb0/pyrefly-0.60.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3e71b70c9b95545cf3b479bc55d1381b531de7b2380eb64411088a1e56b634cb", size = 12424413, upload-time = "2026-04-06T19:57:03.417Z" },
{ url = "https://files.pythonhosted.org/packages/9b/87/eb8dd73abd92a93952ac27a605e463c432fb250fb23186574038c7035594/pyrefly-0.60.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:680ee5f8f98230ea145652d7344708f5375786209c5bf03d8b911fdb0d0d4195", size = 35940884, upload-time = "2026-04-06T19:57:06.909Z" },
{ url = "https://files.pythonhosted.org/packages/0d/34/dc6aeb67b840c745fcee6db358295d554abe6ab555a7eaaf44624bd80bf1/pyrefly-0.60.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d0b20dbbe4aff15b959e8d825b7521a144c4122c11e57022e83b36568c54470", size = 38677220, upload-time = "2026-04-06T19:57:11.235Z" },
{ url = "https://files.pythonhosted.org/packages/66/6b/c863fcf7ef592b7d1db91502acf0d1113be8bed7a2a7143fc6f0dd90616f/pyrefly-0.60.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2911563c8e6b2eaefff68885c94727965469a35375a409235a7a4d2b7157dc15", size = 36907431, upload-time = "2026-04-06T19:57:15.074Z" },
{ url = "https://files.pythonhosted.org/packages/8e/a2/25ea095ab2ecca8e62884669b11a79f14299db93071685b73a97efbaf4f3/pyrefly-0.60.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a631d9d04705e303fe156f2e62551611bc7ef8066c34708ceebcfb3088bd55", size = 41447898, upload-time = "2026-04-06T19:57:19.382Z" },
{ url = "https://files.pythonhosted.org/packages/8e/2c/097bdc6e8d40676b28eb03710a4577bc3c7b803cd24693ac02bf15de3d67/pyrefly-0.60.0-py3-none-win32.whl", hash = "sha256:a08d69298da5626cf502d3debbb6944fd13d2f405ea6625363751f1ff570d366", size = 11913434, upload-time = "2026-04-06T19:57:22.887Z" },
{ url = "https://files.pythonhosted.org/packages/0a/d4/8d27fe310e830c8d11ab73db38b93f9fd2e218744b6efb1204401c9a74d5/pyrefly-0.60.0-py3-none-win_amd64.whl", hash = "sha256:56cf30654e708ae1dd635ffefcba4fa4b349dd7004a6ccc5c41e3a9bb944320c", size = 12745033, upload-time = "2026-04-06T19:57:25.517Z" },
{ url = "https://files.pythonhosted.org/packages/1f/ad/8eea1f8fb8209f91f6dbfe48000c9d05fd0cdb1b5b3157283c9b1dada55d/pyrefly-0.60.0-py3-none-win_arm64.whl", hash = "sha256:b6d27fba970f4777063c0227c54167d83bece1804ea34f69e7118e409ba038d2", size = 12246390, upload-time = "2026-04-06T19:57:28.141Z" },
]

[[package]]

pnpm-lock.yaml (generated, 13 changes)

@ -520,8 +520,8 @@ catalogs:
specifier: 13.0.0
version: 13.0.0
vinext:
specifier: https://pkg.pr.new/vinext@adbf24d
version: 0.0.5
specifier: 0.0.41
version: 0.0.41
vite-plugin-inspect:
specifier: 12.0.0-beta.1
version: 12.0.0-beta.1
@ -1162,7 +1162,7 @@ importers:
version: 3.19.3
vinext:
specifier: 'catalog:'
version: https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
version: 0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
vite:
specifier: npm:@voidzero-dev/vite-plus-core@0.1.16
version: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)'
@ -8336,9 +8336,8 @@ packages:
vfile@6.0.3:
resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}

vinext@https://pkg.pr.new/vinext@adbf24d:
resolution: {tarball: https://pkg.pr.new/vinext@adbf24d}
version: 0.0.5
vinext@0.0.41:
resolution: {integrity: sha512-fpQjNp6cIqjYGH2/kbhN2SdIYHEu79RdlII23SWsY1Qp7LM+je8GfTJH1sxw6dASxPhZKZB/jCmTm5d2/D25zw==}
engines: {node: '>=22'}
hasBin: true
peerDependencies:
@ -16586,7 +16585,7 @@ snapshots:
'@types/unist': 3.0.3
vfile-message: 4.0.3

vinext@https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
vinext@0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
dependencies:
'@unpic/react': 1.0.2(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
'@vercel/og': 0.8.6

@ -221,7 +221,7 @@ catalog:
unist-util-visit: 5.1.0
use-context-selector: 2.0.0
uuid: 13.0.0
vinext: https://pkg.pr.new/vinext@adbf24d
vinext: 0.0.41
vite: npm:@voidzero-dev/vite-plus-core@0.1.16
vite-plugin-inspect: 12.0.0-beta.1
vite-plus: 0.1.16

@ -229,6 +229,23 @@ describe('output-schema-utils', () => {
})
})

describe('Dify compact types (workflow-as-tool output_schema)', () => {
it('should resolve array[string] to arrayString (issue #34428)', () => {
const result = resolveVarType({ type: 'array[string]' })
expect(result.type).toBe(VarType.arrayString)
})

it('should resolve Array[string] case-insensitively', () => {
const result = resolveVarType({ type: 'Array[string]' })
expect(result.type).toBe(VarType.arrayString)
})

it('should resolve array[object] to arrayObject', () => {
const result = resolveVarType({ type: 'array[object]' })
expect(result.type).toBe(VarType.arrayObject)
})
})

describe('unknown types', () => {
it('should resolve unknown type to any', () => {
const result = resolveVarType({ type: 'unknown_type' })

@ -2,6 +2,30 @@ import type { SchemaTypeDefinition } from '@/service/use-common'
import { VarType } from '@/app/components/workflow/types'
import { getMatchedSchemaType } from '../_base/components/variable/use-match-schema-type'

/**
 * Workflow-as-tool and some internal APIs store Dify VarType strings (e.g. `array[string]`)
 * in JSON Schema `type` instead of standard `{ type: 'array', items: { type: 'string' } }`.
 * Map those compact strings to VarType so downstream (e.g. Code node var picker) does not
 * fall back to `any` and get filtered out.
 */
const resolveDifyCompactTypeString = (typeStr: string): VarType | undefined => {
const trimmed = typeStr.trim()
const m = /^array\[(string|number|integer|boolean|object|file|any)\]$/i.exec(trimmed)
if (!m)
return undefined
const inner = m[1].toLowerCase()
const map: Record<string, VarType> = {
string: VarType.arrayString,
number: VarType.arrayNumber,
integer: VarType.arrayNumber,
boolean: VarType.arrayBoolean,
object: VarType.arrayObject,
file: VarType.arrayFile,
any: VarType.arrayAny,
}
return map[inner]
}

/**
 * Normalizes a JSON Schema type to a simple string type.
 * Handles complex schemas with oneOf, anyOf, allOf.
@ -54,6 +78,12 @@ export const resolveVarType = (
schemaTypeDefinitions?: SchemaTypeDefinition[],
): { type: VarType, schemaType?: string } => {
const schemaType = getMatchedSchemaType(schema, schemaTypeDefinitions)
if (schema && typeof schema.type === 'string') {
const compact = resolveDifyCompactTypeString(schema.type)
if (compact !== undefined)
return { type: compact, schemaType }
}

const normalizedType = normalizeJsonSchemaType(schema)

switch (normalizedType) {