diff --git a/api/templates/without-brand/invite_member_mail_template_zh-CN.html b/api/templates/without-brand/invite_member_mail_template_zh-CN.html
index d4f80c66f8..e787c90914 100644
--- a/api/templates/without-brand/invite_member_mail_template_zh-CN.html
+++ b/api/templates/without-brand/invite_member_mail_template_zh-CN.html
@@ -1,69 +1,90 @@
-尊敬的 {{ to }},
-{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。
-点击下方按钮即可登录 {{application_title}} 并且加入空间。
-在此登录
+尊敬的 {{ to }},
+{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。
+点击下方按钮即可登录 {{application_title}} 并且加入空间。
+在此登录
+此致,
+{{application_title}} 团队
diff --git a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html
new file mode 100644
index 0000000000..a5758a2184
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html
@@ -0,0 +1,89 @@
+You are now the owner of {{WorkspaceName}}
+You have been assigned as the new owner of the workspace "{{WorkspaceName}}".
+As the new owner, you now have full administrative privileges for this workspace.
+If you have any questions, please contact support@dify.ai.
diff --git a/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html
new file mode 100644
index 0000000000..53bab92552
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html
@@ -0,0 +1,89 @@
+您现在是 {{WorkspaceName}} 的所有者
+您已被分配为工作空间“{{WorkspaceName}}”的新所有者。
+作为新所有者,您现在对该工作空间拥有完全的管理权限。
+如果您有任何问题,请联系support@dify.ai。
diff --git a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html
new file mode 100644
index 0000000000..3e7faeb01e
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_en-US.html
@@ -0,0 +1,119 @@
+Workspace ownership has been transferred
+You have successfully transferred ownership of the workspace "{{WorkspaceName}}" to {{NewOwnerEmail}}.
+You no longer have owner privileges for this workspace. Your access level has been changed to Admin.
+If you did not initiate this transfer or have concerns about this change, please contact support@dify.ai immediately.
diff --git a/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html
new file mode 100644
index 0000000000..31e3c23140
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html
@@ -0,0 +1,119 @@
+工作区所有权已转移
+您已成功将工作空间“{{WorkspaceName}}”的所有权转移给{{NewOwnerEmail}}。
+您不再拥有此工作空间的拥有者权限。您的访问级别已更改为管理员。
+如果您没有发起此转移或对此变更有任何疑问,请立即联系support@dify.ai。
diff --git a/api/templates/without-brand/transfer_workspace_owner_confirm_template_en-US.html b/api/templates/without-brand/transfer_workspace_owner_confirm_template_en-US.html
new file mode 100644
index 0000000000..11ce275641
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_owner_confirm_template_en-US.html
@@ -0,0 +1,150 @@
+Verify Your Request to Transfer Workspace Ownership
+We received a request to transfer ownership of your workspace “{{WorkspaceName}}”.
+To confirm this action, please use the verification code below.
+This code will only be valid for the next 5 minutes:
+{{code}}
+Please note:
+- The ownership transfer will take effect immediately once confirmed and cannot be undone.
+- You’ll become an admin member, and the new owner will have full control of the workspace.
+If you didn’t make this request, please ignore this email or contact support immediately.
diff --git a/api/templates/without-brand/transfer_workspace_owner_confirm_template_zh-CN.html b/api/templates/without-brand/transfer_workspace_owner_confirm_template_zh-CN.html
new file mode 100644
index 0000000000..36b9a24a89
--- /dev/null
+++ b/api/templates/without-brand/transfer_workspace_owner_confirm_template_zh-CN.html
@@ -0,0 +1,150 @@
+验证您的工作空间所有权转移请求
+我们收到了将您的工作空间“{{WorkspaceName}}”的所有权转移的请求。
+为了确认此操作,请使用以下验证码。
+此验证码仅在5分钟内有效:
+{{code}}
+请注意:
+- 所有权转移一旦确认将立即生效且无法撤销。
+- 您将成为管理员成员,新的所有者将拥有工作空间的完全控制权。
+如果您没有发起此请求,请忽略此邮件或立即联系客服。
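
For reviewers unfamiliar with how these templates are consumed: the `{{ ... }}` placeholders are Jinja-style variables. A minimal sketch of rendering one of the new templates with jinja2, using a file path from this diff (the render call is an illustration, not Dify's actual mail-sending code):

```python
from jinja2 import Template

# Hypothetical rendering of the new-owner notification (illustrative only).
with open(
    "api/templates/without-brand/transfer_workspace_new_owner_notify_template_en-US.html"
) as f:
    template = Template(f.read())

html = template.render(WorkspaceName="Acme Workspace")
```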
diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example
index 4046096c27..2e98dec964 100644
--- a/api/tests/integration_tests/.env.example
+++ b/api/tests/integration_tests/.env.example
@@ -203,6 +203,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
+OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
index 30cd2e60cb..e96d70c4a9 100644
--- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
+++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py
@@ -214,7 +214,7 @@ class TestDraftVariableLoader(unittest.TestCase):
def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session:
- session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.app_id == self._test_app_id).delete(
+ session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
synchronize_session=False
)
session.commit()
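
Side note on the `.filter()` → `.where()` rename above: on the legacy `Query` API, `where()` was added in SQLAlchemy 1.4 as a synonym for `filter()` to match the 2.0-style `select()` construct, so the change is behavior-preserving. A self-contained toy illustration (not Dify's models):

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Draft(Base):
    __tablename__ = "draft"
    id = Column(Integer, primary_key=True)
    app_id = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Query.where() (SQLAlchemy 1.4+) behaves exactly like Query.filter().
    session.query(Draft).where(Draft.app_id == "app-1").delete(synchronize_session=False)
    session.commit()
```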
diff --git a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py b/api/tests/integration_tests/vdb/couchbase/test_couchbase.py
index d76c34ba0e..eef1ee4e75 100644
--- a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py
+++ b/api/tests/integration_tests/vdb/couchbase/test_couchbase.py
@@ -4,7 +4,6 @@ import time
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
- get_example_text,
setup_mock_redis,
)
diff --git a/api/tests/integration_tests/vdb/matrixone/test_matrixone.py b/api/tests/integration_tests/vdb/matrixone/test_matrixone.py
index c8b19ef3ad..c4056db63e 100644
--- a/api/tests/integration_tests/vdb/matrixone/test_matrixone.py
+++ b/api/tests/integration_tests/vdb/matrixone/test_matrixone.py
@@ -1,7 +1,6 @@
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
- get_example_text,
setup_mock_redis,
)
diff --git a/api/tests/integration_tests/vdb/opengauss/test_opengauss.py b/api/tests/integration_tests/vdb/opengauss/test_opengauss.py
index f2013848bf..2a1129493c 100644
--- a/api/tests/integration_tests/vdb/opengauss/test_opengauss.py
+++ b/api/tests/integration_tests/vdb/opengauss/test_opengauss.py
@@ -5,7 +5,6 @@ import psycopg2 # type: ignore
from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
- get_example_text,
setup_mock_redis,
)
diff --git a/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py b/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py
index 3d7873442b..02931fef5a 100644
--- a/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py
+++ b/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py
@@ -1,7 +1,6 @@
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
- get_example_text,
setup_mock_redis,
)
diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py
index da890d0b7c..da549af1b6 100644
--- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py
+++ b/api/tests/integration_tests/vdb/tablestore/test_tablestore.py
@@ -1,4 +1,7 @@
import os
+import uuid
+
+import tablestore
from core.rag.datasource.vdb.tablestore.tablestore_vector import (
TableStoreConfig,
@@ -6,6 +9,8 @@ from core.rag.datasource.vdb.tablestore.tablestore_vector import (
)
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
+ get_example_document,
+ get_example_text,
setup_mock_redis,
)
@@ -29,6 +34,49 @@ class TableStoreVectorTest(AbstractVectorTest):
assert len(ids) == 1
assert ids[0] == self.example_doc_id
+ def create_vector(self):
+ self.vector.create(
+ texts=[get_example_document(doc_id=self.example_doc_id)],
+ embeddings=[self.example_embedding],
+ )
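+        # Poll the search index until the newly written document becomes visible;
+        # Tablestore syncs index data from the base table asynchronously.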
+ while True:
+ search_response = self.vector._tablestore_client.search(
+ table_name=self.vector._table_name,
+ index_name=self.vector._index_name,
+ search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
+ columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
+ )
+ if search_response.total_count == 1:
+ break
+
+ def search_by_vector(self):
+ super().search_by_vector()
+ docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
+ assert len(docs) == 1
+ assert docs[0].metadata["doc_id"] == self.example_doc_id
+ assert docs[0].metadata["score"] > 0
+
+ docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
+ assert len(docs) == 0
+
+ def search_by_full_text(self):
+ super().search_by_full_text()
+ docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
+ assert len(docs) == 1
+ assert docs[0].metadata["doc_id"] == self.example_doc_id
+ assert not hasattr(docs[0], "score")
+
+ docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
+ assert len(docs) == 0
+
+ def run_all_tests(self):
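+        # Drop any index/table left over from a previous run; ignore the error if nothing exists yet.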
+ try:
+ self.vector.delete()
+ except Exception:
+ pass
+
+ return super().run_all_tests()
+
def test_tablestore_vector(setup_mock_redis):
TableStoreVectorTest().run_all_tests()
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py
index 7c48d84d69..330ebfd54a 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/model.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py
@@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
- model_provider_factory = ModelProviderFactory(tenant_id="test_tenant")
+ model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index 13d78c2d83..707b28e6d8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@@ -50,7 +50,7 @@ def init_code_node(code_config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
config=code_config,
)
+ # Initialize node data
+ if "data" in code_config:
+ node.init_node_data(code_config["data"])
+
return node
@@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
- node.node_data = cast(CodeNodeData, node.node_data)
+ node._node_data = cast(CodeNodeData, node._node_data)
# validate
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list():
@@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
]
}
- node.node_data = cast(CodeNodeData, node.node_data)
+ node._node_data = cast(CodeNodeData, node._node_data)
# validate
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -353,4 +357,36 @@ def test_execute_code_output_object_list():
# validate
with pytest.raises(ValueError):
- node._transform_result(result, node.node_data.outputs)
+ node._transform_result(result, node._node_data.outputs)
+
+
+def test_execute_code_scientific_notation():
+ code = """
+ def main() -> dict:
+ return {
+ "result": -8.0E-5
+ }
+ """
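+    # Strip the four-space indentation so the snippet parses as top-level code.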
+ code = "\n".join([line[4:] for line in code.split("\n")])
+
+ code_config = {
+ "id": "code",
+ "data": {
+ "outputs": {
+ "result": {
+ "type": "number",
+ },
+ },
+ "title": "123",
+ "variables": [],
+ "answer": "123",
+ "code_language": "python3",
+ "code": code,
+ },
+ }
+
+ node = init_code_node(code_config)
+ # execute node
+ result = node._run()
+ assert isinstance(result, NodeRunResult)
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
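
The same two changes recur in every node test below (test_http, test_llm, test_parameter_extractor, test_template_transform, test_tool): system variables move from a `SystemVariableKey`-keyed dict to a typed `SystemVariable` object, and node data is initialized explicitly after construction. A condensed sketch of the pattern, with field names taken from this diff (the `SystemVariable` signature is assumed from its usage here):

```python
from core.workflow.system_variable import SystemVariable

# Before: an enum-keyed dict.
# system_variables = {SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}

# After: a keyword-constructed object; fields not needed by a test are simply omitted.
system_variables = SystemVariable(user_id="aaa", files=[])

# Node data is no longer parsed inside the constructor; tests now call:
# node = CodeNode(..., config=code_config)
# node.init_node_data(code_config["data"])
```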
diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py
index 1ab0cc2451..d7856129a3 100644
--- a/api/tests/integration_tests/workflow/nodes/test_http.py
+++ b/api/tests/integration_tests/workflow/nodes/test_http.py
@@ -6,11 +6,11 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.http_request.node import HttpRequestNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@@ -44,7 +44,7 @@ def init_http_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -52,7 +52,7 @@ def init_http_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
- return HttpRequestNode(
+ node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
@@ -60,6 +60,12 @@ def init_http_node(config: dict):
config=config,
)
+ # Initialize node data
+ if "data" in config:
+ node.init_node_data(config["data"])
+
+ return node
+
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 389d1071f3..a14791bc67 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -1,32 +1,24 @@
import json
-import os
import time
import uuid
from collections.abc import Generator
-from decimal import Decimal
from unittest.mock import MagicMock, patch
-import pytest
-
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
-from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
-from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
def init_llm_node(config: dict) -> LLMNode:
@@ -63,12 +55,14 @@ def init_llm_node(config: dict) -> LLMNode:
# construct variable pool
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather today?",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ app_id=app_id,
+ workflow_id=workflow_id,
+ files=[],
+ query="what's the weather today?",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -83,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
config=config,
)
+ # Initialize node data
+ if "data" in config:
+ node.init_node_data(config["data"])
+
return node
-def test_execute_llm(flask_req_ctx):
+def test_execute_llm():
node = init_llm_node(
config={
"id": "llm",
@@ -94,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
"title": "123",
"type": "llm",
"model": {
- "provider": "langgenius/openai/openai",
+ "provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
@@ -113,55 +111,62 @@ def test_execute_llm(flask_req_ctx):
},
)
- credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
+ db.session.close = MagicMock()
- # Create a proper LLM result with real entities
- mock_usage = LLMUsage(
- prompt_tokens=30,
- prompt_unit_price=Decimal("0.001"),
- prompt_price_unit=Decimal("1000"),
- prompt_price=Decimal("0.00003"),
- completion_tokens=20,
- completion_unit_price=Decimal("0.002"),
- completion_price_unit=Decimal("1000"),
- completion_price=Decimal("0.00004"),
- total_tokens=50,
- total_price=Decimal("0.00007"),
- currency="USD",
- latency=0.5,
- )
+ # Mock the _fetch_model_config to avoid database calls
+ def mock_fetch_model_config(**_kwargs):
+ from decimal import Decimal
+ from unittest.mock import MagicMock
- mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
+ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+ from core.model_runtime.entities.message_entities import AssistantPromptMessage
- mock_llm_result = LLMResult(
- model="gpt-3.5-turbo",
- prompt_messages=[],
- message=mock_message,
- usage=mock_usage,
- )
+ # Create mock model instance
+ mock_model_instance = MagicMock()
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
+ mock_message = AssistantPromptMessage(content="Test response from mock")
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
- # Create a simple mock model instance that doesn't call real providers
- mock_model_instance = MagicMock()
- mock_model_instance.invoke_llm.return_value = mock_llm_result
+ # Create mock model config
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.parameters = {}
- # Create a simple mock model config with required attributes
- mock_model_config = MagicMock()
- mock_model_config.mode = "chat"
- mock_model_config.provider = "langgenius/openai/openai"
- mock_model_config.model = "gpt-3.5-turbo"
- mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
-
- # Mock the _fetch_model_config method
- def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
- # Also mock ModelManager.get_model_instance to avoid database calls
- def mock_get_model_instance(_self, **kwargs):
- return mock_model_instance
+ # Mock fetch_prompt_messages to avoid database calls
+ def mock_fetch_prompt_messages_1(**_kwargs):
+ from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
+
+ return [
+ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
+ UserPromptMessage(content="what's the weather today?"),
+ ], []
with (
- patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
- patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
+ patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
# execute node
result = node._run()
@@ -169,6 +174,9 @@ def test_execute_llm(flask_req_ctx):
for item in result:
if isinstance(item, RunCompletedEvent):
+ if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
+ print(f"Error: {item.run_result.error}")
+ print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
@@ -176,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
-@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
-def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
+def test_execute_llm_with_jinja2():
"""
Test execute LLM node with jinja2
"""
@@ -218,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close()
db.session.close = MagicMock()
- # Create a proper LLM result with real entities
- mock_usage = LLMUsage(
- prompt_tokens=30,
- prompt_unit_price=Decimal("0.001"),
- prompt_price_unit=Decimal("1000"),
- prompt_price=Decimal("0.00003"),
- completion_tokens=20,
- completion_unit_price=Decimal("0.002"),
- completion_price_unit=Decimal("1000"),
- completion_price=Decimal("0.00004"),
- total_tokens=50,
- total_price=Decimal("0.00007"),
- currency="USD",
- latency=0.5,
- )
-
- mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
-
- mock_llm_result = LLMResult(
- model="gpt-3.5-turbo",
- prompt_messages=[],
- message=mock_message,
- usage=mock_usage,
- )
-
- # Create a simple mock model instance that doesn't call real providers
- mock_model_instance = MagicMock()
- mock_model_instance.invoke_llm.return_value = mock_llm_result
-
- # Create a simple mock model config with required attributes
- mock_model_config = MagicMock()
- mock_model_config.mode = "chat"
- mock_model_config.provider = "openai"
- mock_model_config.model = "gpt-3.5-turbo"
- mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
-
# Mock the _fetch_model_config method
- def mock_fetch_model_config_func(_node_data_model):
+ def mock_fetch_model_config(**_kwargs):
+ from decimal import Decimal
+ from unittest.mock import MagicMock
+
+ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+ from core.model_runtime.entities.message_entities import AssistantPromptMessage
+
+ # Create mock model instance
+ mock_model_instance = MagicMock()
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
+ mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+ # Create mock model config
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.parameters = {}
+
return mock_model_instance, mock_model_config
- # Also mock ModelManager.get_model_instance to avoid database calls
- def mock_get_model_instance(_self, **kwargs):
- return mock_model_instance
+ # Mock fetch_prompt_messages to avoid database calls
+ def mock_fetch_prompt_messages_2(**_kwargs):
+ from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
+
+ return [
+ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
+ UserPromptMessage(content="what's the weather today?"),
+ ], []
with (
- patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
- patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
+ patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
# execute node
result = node._run()
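
One detail worth calling out in the rewritten LLM tests: the patches now target the `LLMNode` class rather than the already-created `node` instance, so the replacement applies to every instance, not just the one under test. A minimal standalone illustration of the difference:

```python
from unittest.mock import patch

class Node:
    def fetch(self):
        return "real"

node = Node()

# Instance-level patch: only this object is affected.
with patch.object(node, "fetch", lambda: "mock"):
    assert node.fetch() == "mock"
    assert Node().fetch() == "real"

# Class-level patch: every instance, present and future, is affected.
with patch.object(Node, "fetch", lambda self: "mock"):
    assert node.fetch() == "mock"
    assert Node().fetch() == "mock"
```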
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 0df8e8b146..edd70193a8 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -8,11 +8,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
@@ -64,12 +64,9 @@ def init_parameter_extractor_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather in SF",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa"
+ ),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -77,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
- return ParameterExtractorNode(
+ node = ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
+ node.init_node_data(config.get("data", {}))
+ return node
def test_function_calling_parameter_extractor(setup_model_mock):
diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
index a5f2677a59..f71a5ee140 100644
--- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py
+++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py
@@ -6,11 +6,11 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@@ -61,7 +61,7 @@ def test_execute_code(setup_code_executor_mock):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
+ node.init_node_data(config.get("data", {}))
# execute node
result = node._run()
diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py
index 039beedafe..8476c1f874 100644
--- a/api/tests/integration_tests/workflow/nodes/test_tool.py
+++ b/api/tests/integration_tests/workflow/nodes/test_tool.py
@@ -6,12 +6,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -44,19 +44,21 @@ def init_tool_node(config: dict):
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
- return ToolNode(
+ node = ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
+ node.init_node_data(config.get("data", {}))
+ return node
def test_tool_variable_invoke():
diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py
index cac0a688cd..e9d4ee1935 100644
--- a/api/tests/unit_tests/configs/test_dify_config.py
+++ b/api/tests/unit_tests/configs/test_dify_config.py
@@ -1,6 +1,7 @@
import os
from flask import Flask
+from packaging.version import Version
from yarl import URL
from configs.app_config import DifyConfig
@@ -40,6 +41,9 @@ def test_dify_config(monkeypatch):
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
+ # values from pyproject.toml
+ assert Version(config.project.version) >= Version("1.0.0")
+
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
@@ -84,6 +88,7 @@ def test_flask_configs(monkeypatch):
"pool_pre_ping": False,
"pool_recycle": 3600,
"pool_size": 30,
+ "pool_use_lifo": False,
}
assert config["CONSOLE_WEB_URL"] == "https://example.com"
diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py
index 077ffe3408..f484fb22d3 100644
--- a/api/tests/unit_tests/conftest.py
+++ b/api/tests/unit_tests/conftest.py
@@ -26,8 +26,15 @@ redis_mock.hgetall = MagicMock(return_value={})
redis_mock.hdel = MagicMock()
redis_mock.incr = MagicMock(return_value=1)
+# Add the API directory to Python path to ensure proper imports
+import sys
+
+sys.path.insert(0, PROJECT_DIR)
+
# apply the mock to the Redis client in the Flask app
-redis_patcher = patch("extensions.ext_redis.redis_client", redis_mock)
+from extensions import ext_redis
+
+redis_patcher = patch.object(ext_redis, "redis_client", redis_mock)
redis_patcher.start()
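
The conftest change swaps the dotted-string target for `patch.object` on a module imported directly; combined with the `sys.path` insert, this ensures the patch lands on the same module object the application code sees, rather than whatever `extensions.ext_redis` resolves to at patch time. The two forms are equivalent once the module is importable, shown here on a toy module:

```python
import json
from unittest.mock import MagicMock, patch

mock_loads = MagicMock(return_value={})

# String form: the target module is re-resolved through the import system.
with patch("json.loads", mock_loads):
    assert json.loads("anything") == {}

# Object form: the attribute is replaced on the module object we already hold.
with patch.object(json, "loads", mock_loads):
    assert json.loads("anything") == {}
```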
diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_oauth.py
new file mode 100644
index 0000000000..037c9f2745
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/auth/test_oauth.py
@@ -0,0 +1,496 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask
+
+from controllers.console.auth.oauth import (
+ OAuthCallback,
+ OAuthLogin,
+ _generate_account,
+ _get_account_by_openid_or_email,
+ get_oauth_providers,
+)
+from libs.oauth import OAuthUserInfo
+from models.account import AccountStatus
+from services.errors.account import AccountNotFoundError
+
+
+class TestGetOAuthProviders:
+ @pytest.fixture
+ def app(self):
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ return app
+
+ @pytest.mark.parametrize(
+ ("github_config", "google_config", "expected_github", "expected_google"),
+ [
+ # Both providers configured
+ (
+ {"id": "github_id", "secret": "github_secret"},
+ {"id": "google_id", "secret": "google_secret"},
+ True,
+ True,
+ ),
+ # Only GitHub configured
+ ({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
+ # Only Google configured
+ ({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
+ # No providers configured
+ ({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.dify_config")
+ def test_should_configure_oauth_providers_correctly(
+ self, mock_config, app, github_config, google_config, expected_github, expected_google
+ ):
+ mock_config.GITHUB_CLIENT_ID = github_config["id"]
+ mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
+ mock_config.GOOGLE_CLIENT_ID = google_config["id"]
+ mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
+ mock_config.CONSOLE_API_URL = "http://localhost"
+
+ with app.app_context():
+ providers = get_oauth_providers()
+
+ assert (providers["github"] is not None) == expected_github
+ assert (providers["google"] is not None) == expected_google
+
+
+class TestOAuthLogin:
+ @pytest.fixture
+ def resource(self):
+ return OAuthLogin()
+
+ @pytest.fixture
+ def app(self):
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ return app
+
+ @pytest.fixture
+ def mock_oauth_provider(self):
+ provider = MagicMock()
+ provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
+ return provider
+
+ @pytest.mark.parametrize(
+ ("invite_token", "expected_token"),
+ [
+ (None, None),
+ ("test_invite_token", "test_invite_token"),
+ ("", None),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_should_handle_oauth_login_with_various_tokens(
+ self,
+ mock_redirect,
+ mock_get_providers,
+ resource,
+ app,
+ mock_oauth_provider,
+ invite_token,
+ expected_token,
+ ):
+ mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
+
+ query_string = f"invite_token={invite_token}" if invite_token else ""
+ with app.test_request_context(f"/auth/oauth/github?{query_string}"):
+ resource.get("github")
+
+ mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
+ mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
+
+ @pytest.mark.parametrize(
+ ("provider", "expected_error"),
+ [
+ ("invalid_provider", "Invalid provider"),
+ ("github", "Invalid provider"), # When GitHub is not configured
+ ("google", "Invalid provider"), # When Google is not configured
+ ],
+ )
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ def test_should_return_error_for_invalid_providers(
+ self, mock_get_providers, resource, app, provider, expected_error
+ ):
+ mock_get_providers.return_value = {"github": None, "google": None}
+
+ with app.test_request_context(f"/auth/oauth/{provider}"):
+ response, status_code = resource.get(provider)
+
+ assert status_code == 400
+ assert response["error"] == expected_error
+
+
+class TestOAuthCallback:
+ @pytest.fixture
+ def resource(self):
+ return OAuthCallback()
+
+ @pytest.fixture
+ def app(self):
+ app = Flask(__name__)
+ app.config["TESTING"] = True
+ return app
+
+ @pytest.fixture
+ def oauth_setup(self):
+ """Common OAuth setup for callback tests"""
+ oauth_provider = MagicMock()
+ oauth_provider.get_access_token.return_value = "access_token"
+ oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
+
+ account = MagicMock()
+ account.status = AccountStatus.ACTIVE.value
+
+ token_pair = MagicMock()
+ token_pair.access_token = "jwt_access_token"
+ token_pair.refresh_token = "jwt_refresh_token"
+
+ return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
+
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_should_handle_successful_oauth_callback(
+ self,
+ mock_redirect,
+ mock_tenant_service,
+ mock_account_service,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ resource,
+ app,
+ oauth_setup,
+ ):
+ mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+ mock_generate_account.return_value = oauth_setup["account"]
+ mock_account_service.login.return_value = oauth_setup["token_pair"]
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
+ oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
+ mock_redirect.assert_called_once_with(
+ "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
+ )
+
+ @pytest.mark.parametrize(
+ ("exception", "expected_error"),
+ [
+ (Exception("OAuth error"), "OAuth process failed"),
+ (ValueError("Invalid token"), "OAuth process failed"),
+ (KeyError("Missing key"), "OAuth process failed"),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ def test_should_handle_oauth_exceptions(
+ self, mock_get_providers, mock_db, resource, app, exception, expected_error
+ ):
+ # Mock database session
+ mock_db.session = MagicMock()
+ mock_db.session.rollback = MagicMock()
+
+ # Import the real requests module to create a proper exception
+ import requests
+
+ request_exception = requests.exceptions.RequestException("OAuth error")
+ request_exception.response = MagicMock()
+ request_exception.response.text = str(exception)
+
+ mock_oauth_provider = MagicMock()
+ mock_oauth_provider.get_access_token.side_effect = request_exception
+ mock_get_providers.return_value = {"github": mock_oauth_provider}
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ response, status_code = resource.get("github")
+
+ assert status_code == 400
+ assert response["error"] == expected_error
+
+ @pytest.mark.parametrize(
+ ("account_status", "expected_redirect"),
+ [
+ (AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
+ # CLOSED status: Currently NOT handled, will proceed to login (security issue)
+ # This documents actual behavior. See test_defensive_check_for_closed_account_status for details
+ (
+ AccountStatus.CLOSED.value,
+ "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
+ ),
+ ],
+ )
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_should_redirect_based_on_account_status(
+ self,
+ mock_redirect,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ mock_db,
+ mock_tenant_service,
+ mock_account_service,
+ resource,
+ app,
+ oauth_setup,
+ account_status,
+ expected_redirect,
+ ):
+ # Mock database session
+ mock_db.session = MagicMock()
+ mock_db.session.rollback = MagicMock()
+ mock_db.session.commit = MagicMock()
+
+ mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+
+ account = MagicMock()
+ account.status = account_status
+ account.id = "123"
+ mock_generate_account.return_value = account
+
+ # Mock login for CLOSED status
+ mock_token_pair = MagicMock()
+ mock_token_pair.access_token = "jwt_access_token"
+ mock_token_pair.refresh_token = "jwt_refresh_token"
+ mock_account_service.login.return_value = mock_token_pair
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ mock_redirect.assert_called_once_with(expected_redirect)
+
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ def test_should_activate_pending_account(
+ self,
+ mock_account_service,
+ mock_tenant_service,
+ mock_db,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ resource,
+ app,
+ oauth_setup,
+ ):
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+
+ mock_account = MagicMock()
+ mock_account.status = AccountStatus.PENDING.value
+ mock_generate_account.return_value = mock_account
+
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ assert mock_account.status == AccountStatus.ACTIVE.value
+ assert mock_account.initialized_at is not None
+ mock_db.session.commit.assert_called_once()
+
+ @patch("controllers.console.auth.oauth.dify_config")
+ @patch("controllers.console.auth.oauth.get_oauth_providers")
+ @patch("controllers.console.auth.oauth._generate_account")
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.redirect")
+ def test_defensive_check_for_closed_account_status(
+ self,
+ mock_redirect,
+ mock_account_service,
+ mock_tenant_service,
+ mock_db,
+ mock_generate_account,
+ mock_get_providers,
+ mock_config,
+ resource,
+ app,
+ oauth_setup,
+ ):
+ """Defensive test for CLOSED account status handling in OAuth callback.
+
+ This is a defensive test documenting expected security behavior for CLOSED accounts.
+
+ Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
+ Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
+
+ Context:
+ - AccountStatus.CLOSED is defined in the enum but never used in production
+ - The close_account() method exists but is never called
+ - Account deletion uses external service instead of status change
+ - All authentication services (OAuth, password, email) don't check CLOSED status
+
+ TODO: If CLOSED status is implemented in the future:
+ 1. Update OAuth callback to check for CLOSED status
+ 2. Add similar checks to all authentication services for consistency
+ 3. Update this test to verify the rejection behavior
+
+ Security consideration: Until properly implemented, CLOSED status provides no protection.
+ """
+ # Setup
+ mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
+ mock_get_providers.return_value = {"github": oauth_setup["provider"]}
+
+ # Create account with CLOSED status
+ closed_account = MagicMock()
+ closed_account.status = AccountStatus.CLOSED.value
+ closed_account.id = "123"
+ closed_account.name = "Closed Account"
+ mock_generate_account.return_value = closed_account
+
+ # Mock successful login (current behavior)
+ mock_token_pair = MagicMock()
+ mock_token_pair.access_token = "jwt_access_token"
+ mock_token_pair.refresh_token = "jwt_refresh_token"
+ mock_account_service.login.return_value = mock_token_pair
+
+ # Execute OAuth callback
+ with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
+ resource.get("github")
+
+ # Verify current behavior: login succeeds (this is NOT ideal)
+ mock_redirect.assert_called_once_with(
+ "http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
+ )
+ mock_account_service.login.assert_called_once()
+
+ # Document expected behavior in comments:
+ # Expected: mock_redirect.assert_called_once_with(
+ # "http://localhost:3000/signin?message=Account is closed."
+ # )
+ # Expected: mock_account_service.login.assert_not_called()
+
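
As the docstring above spells out, the fix it anticipates would mirror the existing BANNED handling in the OAuth callback. A hypothetical sketch of that check, not present in the codebase (names taken from this test's patch targets):

```python
# Hypothetical addition to OAuthCallback.get(), mirroring the BANNED branch:
if account.status == AccountStatus.CLOSED.value:
    return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is closed.")
```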
+
+class TestAccountGeneration:
+ @pytest.fixture
+ def user_info(self):
+ return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
+
+ @pytest.fixture
+ def mock_account(self):
+ account = MagicMock()
+ account.name = "Test User"
+ return account
+
+ @patch("controllers.console.auth.oauth.db")
+ @patch("controllers.console.auth.oauth.Account")
+ @patch("controllers.console.auth.oauth.Session")
+ @patch("controllers.console.auth.oauth.select")
+ def test_should_get_account_by_openid_or_email(
+ self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
+ ):
+ # Mock db.engine for Session creation
+ mock_db.engine = MagicMock()
+
+ # Test OpenID found
+ mock_account_model.get_by_openid.return_value = mock_account
+ result = _get_account_by_openid_or_email("github", user_info)
+ assert result == mock_account
+ mock_account_model.get_by_openid.assert_called_once_with("github", "123")
+
+ # Test fallback to email
+ mock_account_model.get_by_openid.return_value = None
+ mock_session_instance = MagicMock()
+ mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
+ mock_session.return_value.__enter__.return_value = mock_session_instance
+
+ result = _get_account_by_openid_or_email("github", user_info)
+ assert result == mock_account
+
+ @pytest.mark.parametrize(
+ ("allow_register", "existing_account", "should_create"),
+ [
+ (True, None, True), # New account creation allowed
+ (True, "existing", False), # Existing account
+ (False, None, False), # Registration not allowed
+ ],
+ )
+ @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
+ @patch("controllers.console.auth.oauth.FeatureService")
+ @patch("controllers.console.auth.oauth.RegisterService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.db")
+ def test_should_handle_account_generation_scenarios(
+ self,
+ mock_db,
+ mock_tenant_service,
+ mock_account_service,
+ mock_register_service,
+ mock_feature_service,
+ mock_get_account,
+ app,
+ user_info,
+ mock_account,
+ allow_register,
+ existing_account,
+ should_create,
+ ):
+ mock_get_account.return_value = mock_account if existing_account else None
+ mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
+ mock_register_service.register.return_value = mock_account
+
+ with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
+ if not allow_register and not existing_account:
+ with pytest.raises(AccountNotFoundError):
+ _generate_account("github", user_info)
+ else:
+ result = _generate_account("github", user_info)
+ assert result == mock_account
+
+ if should_create:
+ mock_register_service.register.assert_called_once_with(
+ email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
+ )
+
+ @patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
+ @patch("controllers.console.auth.oauth.TenantService")
+ @patch("controllers.console.auth.oauth.FeatureService")
+ @patch("controllers.console.auth.oauth.AccountService")
+ @patch("controllers.console.auth.oauth.tenant_was_created")
+ def test_should_create_workspace_for_account_without_tenant(
+ self,
+ mock_event,
+ mock_account_service,
+ mock_feature_service,
+ mock_tenant_service,
+ mock_get_account,
+ app,
+ user_info,
+ mock_account,
+ ):
+ mock_get_account.return_value = mock_account
+ mock_tenant_service.get_join_tenants.return_value = []
+ mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
+
+ mock_new_tenant = MagicMock()
+ mock_tenant_service.create_tenant.return_value = mock_new_tenant
+
+ with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
+ result = _generate_account("github", user_info)
+
+ assert result == mock_account
+ mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
+ mock_tenant_service.create_tenant_member.assert_called_once_with(
+ mock_new_tenant, mock_account, role="owner"
+ )
+ mock_event.send.assert_called_once_with(mock_new_tenant)
diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py
new file mode 100644
index 0000000000..9742368f04
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/test_wraps.py
@@ -0,0 +1,380 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask
+from flask_login import LoginManager, UserMixin
+
+from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
+from controllers.console.workspace.error import AccountNotInitializedError
+from controllers.console.wraps import (
+ account_initialization_required,
+ cloud_edition_billing_rate_limit_check,
+ cloud_edition_billing_resource_check,
+ enterprise_license_required,
+ only_edition_cloud,
+ only_edition_enterprise,
+ only_edition_self_hosted,
+ setup_required,
+)
+from models.account import AccountStatus
+from services.feature_service import LicenseStatus
+
+
+class MockUser(UserMixin):
+ """Simple User class for testing."""
+
+ def __init__(self, user_id: str):
+ self.id = user_id
+ self.current_tenant_id = "tenant123"
+
+ def get_id(self) -> str:
+ return self.id
+
+
+def create_app_with_login():
+ """Create a Flask app with LoginManager configured."""
+ app = Flask(__name__)
+ app.config["SECRET_KEY"] = "test-secret-key"
+
+ login_manager = LoginManager()
+ login_manager.init_app(app)
+
+ @login_manager.user_loader
+ def load_user(user_id: str):
+ return MockUser(user_id)
+
+ return app
+
+
+class TestAccountInitialization:
+ """Test account initialization decorator"""
+
+ def test_should_allow_initialized_account(self):
+ """Test that initialized accounts can access protected views"""
+ # Arrange
+ mock_user = MagicMock()
+ mock_user.status = AccountStatus.ACTIVE
+
+ @account_initialization_required
+ def protected_view():
+ return "success"
+
+ # Act
+ with patch("controllers.console.wraps.current_user", mock_user):
+ result = protected_view()
+
+ # Assert
+ assert result == "success"
+
+ def test_should_reject_uninitialized_account(self):
+ """Test that uninitialized accounts raise AccountNotInitializedError"""
+ # Arrange
+ mock_user = MagicMock()
+ mock_user.status = AccountStatus.UNINITIALIZED
+
+ @account_initialization_required
+ def protected_view():
+ return "success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.current_user", mock_user):
+ with pytest.raises(AccountNotInitializedError):
+ protected_view()
+
+
+class TestEditionChecks:
+ """Test edition-specific decorators"""
+
+ def test_only_edition_cloud_allows_cloud_edition(self):
+ """Test cloud edition decorator allows CLOUD edition"""
+
+ # Arrange
+ @only_edition_cloud
+ def cloud_view():
+ return "cloud_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"):
+ result = cloud_view()
+
+ # Assert
+ assert result == "cloud_success"
+
+ def test_only_edition_cloud_rejects_other_editions(self):
+ """Test cloud edition decorator rejects non-CLOUD editions"""
+ # Arrange
+ app = Flask(__name__)
+
+ @only_edition_cloud
+ def cloud_view():
+ return "cloud_success"
+
+ # Act & Assert
+ with app.test_request_context():
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ with pytest.raises(Exception) as exc_info:
+ cloud_view()
+ assert exc_info.value.code == 404
+
+ def test_only_edition_enterprise_allows_when_enabled(self):
+ """Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True"""
+
+ # Arrange
+ @only_edition_enterprise
+ def enterprise_view():
+ return "enterprise_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True):
+ result = enterprise_view()
+
+ # Assert
+ assert result == "enterprise_success"
+
+ def test_only_edition_self_hosted_allows_self_hosted(self):
+ """Test self-hosted edition decorator allows SELF_HOSTED edition"""
+
+ # Arrange
+ @only_edition_self_hosted
+ def self_hosted_view():
+ return "self_hosted_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ result = self_hosted_view()
+
+ # Assert
+ assert result == "self_hosted_success"
+
+
+class TestBillingResourceLimits:
+ """Test billing resource limit decorators"""
+
+ def test_should_allow_when_under_resource_limit(self):
+ """Test that requests are allowed when under resource limits"""
+ # Arrange
+ mock_features = MagicMock()
+ mock_features.billing.enabled = True
+ mock_features.members.limit = 10
+ mock_features.members.size = 5
+
+ @cloud_edition_billing_resource_check("members")
+ def add_member():
+ return "member_added"
+
+ # Act
+ with patch("controllers.console.wraps.current_user"):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ result = add_member()
+
+ # Assert
+ assert result == "member_added"
+
+ def test_should_reject_when_over_resource_limit(self):
+ """Test that requests are rejected when over resource limits"""
+ # Arrange
+ app = create_app_with_login()
+ mock_features = MagicMock()
+ mock_features.billing.enabled = True
+ mock_features.members.limit = 10
+ mock_features.members.size = 10
+
+ @cloud_edition_billing_resource_check("members")
+ def add_member():
+ return "member_added"
+
+ # Act & Assert
+ with app.test_request_context():
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ with pytest.raises(Exception) as exc_info:
+ add_member()
+ assert exc_info.value.code == 403
+ assert "members has reached the limit" in str(exc_info.value.description)
+
+ def test_should_check_source_for_documents_limit(self):
+ """Test document limit checks request source"""
+ # Arrange
+ app = create_app_with_login()
+ mock_features = MagicMock()
+ mock_features.billing.enabled = True
+ mock_features.documents_upload_quota.limit = 100
+ mock_features.documents_upload_quota.size = 100
+
+ @cloud_edition_billing_resource_check("documents")
+ def upload_document():
+ return "document_uploaded"
+
+ # Test 1: Should reject when source is datasets
+ with app.test_request_context("/?source=datasets"):
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ with pytest.raises(Exception) as exc_info:
+ upload_document()
+ assert exc_info.value.code == 403
+
+ # Test 2: Should allow when source is not datasets
+ with app.test_request_context("/?source=other"):
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
+ result = upload_document()
+ assert result == "document_uploaded"
+
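+
+# Hedged sketch of the branch the two subtests above pin down (an assumption
+# about the decorator's behavior, not its source): the documents quota is
+# enforced only when the request's `source` query parameter is "datasets".
+def _documents_quota_exceeded_sketch(source: str | None, size: int, limit: int) -> bool:
+    return source == "datasets" and limit > 0 and size >= limit
+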
+
+class TestRateLimiting:
+ """Test rate limiting decorator"""
+
+ @patch("controllers.console.wraps.redis_client")
+ @patch("controllers.console.wraps.db")
+ def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
+ """Test that requests within rate limit are allowed"""
+ # Arrange
+ mock_rate_limit = MagicMock()
+ mock_rate_limit.enabled = True
+ mock_rate_limit.limit = 10
+ mock_redis.zcard.return_value = 5 # 5 requests in window
+
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def knowledge_request():
+ return "knowledge_success"
+
+ # Act
+ with patch("controllers.console.wraps.current_user"):
+ with patch(
+ "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
+ ):
+ result = knowledge_request()
+
+ # Assert
+ assert result == "knowledge_success"
+ mock_redis.zadd.assert_called_once()
+ mock_redis.zremrangebyscore.assert_called_once()
+
+ @patch("controllers.console.wraps.redis_client")
+ @patch("controllers.console.wraps.db")
+ def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
+ """Test that requests over rate limit are rejected and logged"""
+ # Arrange
+ app = create_app_with_login()
+ mock_rate_limit = MagicMock()
+ mock_rate_limit.enabled = True
+ mock_rate_limit.limit = 10
+ mock_rate_limit.subscription_plan = "pro"
+ mock_redis.zcard.return_value = 11 # Over limit
+
+ mock_session = MagicMock()
+ mock_db.session = mock_session
+
+ @cloud_edition_billing_rate_limit_check("knowledge")
+ def knowledge_request():
+ return "knowledge_success"
+
+ # Act & Assert
+ with app.test_request_context():
+ with patch("controllers.console.wraps.current_user", MockUser("test_user")):
+ with patch(
+ "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
+ ):
+ with pytest.raises(Exception) as exc_info:
+ knowledge_request()
+
+ # Verify error
+ assert exc_info.value.code == 403
+ assert "rate limit" in str(exc_info.value.description)
+
+ # Verify rate limit log was created
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_called_once()
+
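+
+# Hedged sketch of the sliding-window pattern the redis assertions above
+# imply (zremrangebyscore + zcard + zadd on a sorted set). The key name and
+# window length are illustrative assumptions, not Dify's actual values.
+def _sliding_window_allow_sketch(redis_client, tenant_id: str, limit: int, window_seconds: int = 60) -> bool:
+    import time
+
+    key = f"knowledge_rate_limit:{tenant_id}"  # hypothetical key
+    now = time.time()
+    redis_client.zremrangebyscore(key, 0, now - window_seconds)  # evict entries outside the window
+    if redis_client.zcard(key) >= limit:  # requests still inside the window
+        return False
+    redis_client.zadd(key, {str(now): now})  # record this request
+    return True
+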
+
+class TestSystemSetup:
+ """Test system setup decorator"""
+
+ @patch("controllers.console.wraps.db")
+ def test_should_allow_when_setup_complete(self, mock_db):
+ """Test that requests are allowed when setup is complete"""
+ # Arrange
+ mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
+
+ @setup_required
+ def admin_view():
+ return "admin_success"
+
+ # Act
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ result = admin_view()
+
+ # Assert
+ assert result == "admin_success"
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.wraps.os.environ.get")
+ def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
+ """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
+ # Arrange
+ mock_db.session.query.return_value.first.return_value = None # No setup
+ mock_environ_get.return_value = "some_password"
+
+ @setup_required
+ def admin_view():
+ return "admin_success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ with pytest.raises(NotInitValidateError):
+ admin_view()
+
+ @patch("controllers.console.wraps.db")
+ @patch("controllers.console.wraps.os.environ.get")
+ def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
+ """Test NotSetupError when no INIT_PASSWORD and setup not complete"""
+ # Arrange
+ mock_db.session.query.return_value.first.return_value = None # No setup
+ mock_environ_get.return_value = None # No INIT_PASSWORD
+
+ @setup_required
+ def admin_view():
+ return "admin_success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
+ with pytest.raises(NotSetupError):
+ admin_view()
+
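+
+# Hedged sketch of the decision the three tests above encode (an assumed
+# shape, not the decorator's actual code): when no setup record exists in
+# self-hosted mode, a set INIT_PASSWORD yields NotInitValidateError and a
+# missing one yields NotSetupError.
+def _setup_gate_sketch(setup_record, init_password):
+    if setup_record is None:
+        if init_password:
+            raise NotInitValidateError()
+        raise NotSetupError()
+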
+
+class TestEnterpriseLicense:
+ """Test enterprise license decorator"""
+
+ def test_should_allow_with_valid_license(self):
+ """Test that valid licenses allow access"""
+ # Arrange
+ mock_settings = MagicMock()
+ mock_settings.license.status = LicenseStatus.ACTIVE
+
+ @enterprise_license_required
+ def enterprise_feature():
+ return "enterprise_success"
+
+ # Act
+ with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
+ result = enterprise_feature()
+
+ # Assert
+ assert result == "enterprise_success"
+
+ @pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
+ def test_should_reject_with_invalid_license(self, invalid_status):
+ """Test that invalid licenses raise UnauthorizedAndForceLogout"""
+ # Arrange
+ mock_settings = MagicMock()
+ mock_settings.license.status = invalid_status
+
+ @enterprise_license_required
+ def enterprise_feature():
+ return "enterprise_success"
+
+ # Act & Assert
+ with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
+ with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
+ enterprise_feature()
+ assert "license is invalid" in str(exc_info.value)
diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py
new file mode 100644
index 0000000000..5890009742
--- /dev/null
+++ b/api/tests/unit_tests/core/helper/test_encrypter.py
@@ -0,0 +1,280 @@
+import base64
+import binascii
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.helper.encrypter import (
+ batch_decrypt_token,
+ decrypt_token,
+ encrypt_token,
+ get_decrypt_decoding,
+ obfuscated_token,
+)
+from libs.rsa import PrivkeyNotFoundError
+
+
+class TestObfuscatedToken:
+ @pytest.mark.parametrize(
+ ("token", "expected"),
+ [
+ ("", ""), # Empty token
+ ("1234567", "*" * 20), # Short token (<8 chars)
+ ("12345678", "*" * 20), # Boundary case (8 chars)
+ ("123456789abcdef", "123456" + "*" * 12 + "ef"), # Long token
+ ("abc!@#$%^&*()def", "abc!@#" + "*" * 12 + "ef"), # Special chars
+ ],
+ )
+ def test_obfuscation_logic(self, token, expected):
+ """Test core obfuscation logic for various token lengths"""
+ assert obfuscated_token(token) == expected
+
+ def test_sensitive_data_protection(self):
+ """Ensure obfuscation never reveals full sensitive data"""
+ token = "api_key_secret_12345"
+ obfuscated = obfuscated_token(token)
+ assert token not in obfuscated
+ assert "*" * 12 in obfuscated
+
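+
+# Hedged reference sketch of the rule the parametrize table above encodes (an
+# assumption, not the helper's source): tokens of 8 characters or fewer
+# collapse to 20 stars; longer tokens keep the first 6 and last 2 characters
+# around 12 stars.
+def _obfuscate_sketch(token: str) -> str:
+    if not token:
+        return ""
+    if len(token) <= 8:
+        return "*" * 20
+    return token[:6] + "*" * 12 + token[-2:]
+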
+
+class TestEncryptToken:
+ @patch("models.engine.db.session.query")
+ @patch("libs.rsa.encrypt")
+ def test_successful_encryption(self, mock_encrypt, mock_query):
+ """Test successful token encryption"""
+ mock_tenant = MagicMock()
+ mock_tenant.encrypt_public_key = "mock_public_key"
+ mock_query.return_value.where.return_value.first.return_value = mock_tenant
+ mock_encrypt.return_value = b"encrypted_data"
+
+ result = encrypt_token("tenant-123", "test_token")
+
+ assert result == base64.b64encode(b"encrypted_data").decode()
+ mock_encrypt.assert_called_with("test_token", "mock_public_key")
+
+ @patch("models.engine.db.session.query")
+ def test_tenant_not_found(self, mock_query):
+ """Test error when tenant doesn't exist"""
+ mock_query.return_value.where.return_value.first.return_value = None
+
+ with pytest.raises(ValueError) as exc_info:
+ encrypt_token("invalid-tenant", "test_token")
+
+ assert "Tenant with id invalid-tenant not found" in str(exc_info.value)
+
+
+class TestDecryptToken:
+ @patch("libs.rsa.decrypt")
+ def test_successful_decryption(self, mock_decrypt):
+ """Test successful token decryption"""
+ mock_decrypt.return_value = "decrypted_token"
+ encrypted_data = base64.b64encode(b"encrypted_data").decode()
+
+ result = decrypt_token("tenant-123", encrypted_data)
+
+ assert result == "decrypted_token"
+ mock_decrypt.assert_called_once_with(b"encrypted_data", "tenant-123")
+
+ def test_invalid_base64(self):
+ """Test handling of invalid base64 input"""
+ with pytest.raises(binascii.Error):
+ decrypt_token("tenant-123", "invalid_base64!!!")
+
+
+class TestBatchDecryptToken:
+ @patch("libs.rsa.get_decrypt_decoding")
+ @patch("libs.rsa.decrypt_token_with_decoding")
+ def test_batch_decryption(self, mock_decrypt_with_decoding, mock_get_decoding):
+ """Test batch decryption functionality"""
+ mock_rsa_key = MagicMock()
+ mock_cipher_rsa = MagicMock()
+ mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
+
+ # Test multiple tokens
+ mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3"]
+ tokens = [
+ base64.b64encode(b"encrypted1").decode(),
+ base64.b64encode(b"encrypted2").decode(),
+ base64.b64encode(b"encrypted3").decode(),
+ ]
+ result = batch_decrypt_token("tenant-123", tokens)
+
+ assert result == ["token1", "token2", "token3"]
+ # Key should only be loaded once
+ mock_get_decoding.assert_called_once_with("tenant-123")
+
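+
+# Hedged sketch of the optimization the test above asserts (not the helper's
+# actual code; the decrypt_token_with_decoding signature is an assumption):
+# resolve the tenant's key material once, then decrypt every token with it.
+def _batch_decrypt_sketch(tenant_id: str, encrypted_tokens: list[str]) -> list[str]:
+    from libs.rsa import decrypt_token_with_decoding
+
+    rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)  # single key load
+    return [decrypt_token_with_decoding(base64.b64decode(tok), rsa_key, cipher_rsa) for tok in encrypted_tokens]
+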
+
+class TestGetDecryptDecoding:
+ @patch("extensions.ext_redis.redis_client.get")
+ @patch("extensions.ext_storage.storage.load")
+ def test_private_key_not_found(self, mock_storage_load, mock_redis_get):
+ """Test error when private key file doesn't exist"""
+ mock_redis_get.return_value = None
+ mock_storage_load.side_effect = FileNotFoundError()
+
+ with pytest.raises(PrivkeyNotFoundError) as exc_info:
+ get_decrypt_decoding("tenant-123")
+
+ assert "Private key not found, tenant_id: tenant-123" in str(exc_info.value)
+
+
+class TestEncryptDecryptIntegration:
+ @patch("models.engine.db.session.query")
+ @patch("libs.rsa.encrypt")
+ @patch("libs.rsa.decrypt")
+ def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
+ """Test that encryption and decryption are consistent"""
+ # Setup mock tenant
+ mock_tenant = MagicMock()
+ mock_tenant.encrypt_public_key = "mock_public_key"
+ mock_query.return_value.where.return_value.first.return_value = mock_tenant
+
+ # Setup mock encryption/decryption
+ original_token = "test_token_123"
+ mock_encrypt.return_value = b"encrypted_data"
+ mock_decrypt.return_value = original_token
+
+ # Test encryption
+ encrypted = encrypt_token("tenant-123", original_token)
+
+ # Test decryption
+ decrypted = decrypt_token("tenant-123", encrypted)
+
+ assert decrypted == original_token
+
+
+class TestSecurity:
+ """Critical security tests for encryption system"""
+
+ @patch("models.engine.db.session.query")
+ @patch("libs.rsa.encrypt")
+ def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
+ """Ensure tokens encrypted for one tenant cannot be used by another"""
+ # Setup mock tenant
+ mock_tenant = MagicMock()
+ mock_tenant.encrypt_public_key = "tenant1_public_key"
+ mock_query.return_value.where.return_value.first.return_value = mock_tenant
+ mock_encrypt.return_value = b"encrypted_for_tenant1"
+
+ # Encrypt token for tenant1
+ encrypted = encrypt_token("tenant-123", "sensitive_data")
+
+ # Attempt to decrypt with different tenant should fail
+ with patch("libs.rsa.decrypt") as mock_decrypt:
+ mock_decrypt.side_effect = Exception("Invalid tenant key")
+
+ with pytest.raises(Exception, match="Invalid tenant key"):
+ decrypt_token("different-tenant", encrypted)
+
+ @patch("libs.rsa.decrypt")
+ def test_tampered_ciphertext_rejection(self, mock_decrypt):
+ """Detect and reject tampered ciphertext"""
+ valid_encrypted = base64.b64encode(b"valid_data").decode()
+
+ # Tamper with ciphertext
+ tampered_bytes = bytearray(base64.b64decode(valid_encrypted))
+ tampered_bytes[0] ^= 0xFF
+ tampered = base64.b64encode(bytes(tampered_bytes)).decode()
+
+ mock_decrypt.side_effect = Exception("Decryption error")
+
+ with pytest.raises(Exception, match="Decryption error"):
+ decrypt_token("tenant-123", tampered)
+
+ @patch("models.engine.db.session.query")
+ @patch("libs.rsa.encrypt")
+ def test_encryption_randomness(self, mock_encrypt, mock_query):
+ """Ensure same plaintext produces different ciphertext"""
+ mock_tenant = MagicMock(encrypt_public_key="key")
+ mock_query.return_value.where.return_value.first.return_value = mock_tenant
+
+ # Different outputs for same input
+ mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
+
+ results = [encrypt_token("tenant-123", "token") for _ in range(3)]
+
+ # All results should be different
+ assert len(set(results)) == 3
+
+
+class TestEdgeCases:
+ """Additional security-focused edge case tests"""
+
+ def test_should_handle_empty_string_in_obfuscation(self):
+ """Test handling of empty string in obfuscation"""
+ # Test empty string (which is a valid str type)
+ assert obfuscated_token("") == ""
+
+ @patch("models.engine.db.session.query")
+ @patch("libs.rsa.encrypt")
+ def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
+ """Test encryption of empty token"""
+ mock_tenant = MagicMock()
+ mock_tenant.encrypt_public_key = "mock_public_key"
+ mock_query.return_value.where.return_value.first.return_value = mock_tenant
+ mock_encrypt.return_value = b"encrypted_empty"
+
+ result = encrypt_token("tenant-123", "")
+
+ assert result == base64.b64encode(b"encrypted_empty").decode()
+ mock_encrypt.assert_called_with("", "mock_public_key")
+
+ @patch("models.engine.db.session.query")
+ @patch("libs.rsa.encrypt")
+ def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
+ """Test tokens containing special/unicode characters"""
+ mock_tenant = MagicMock()
+ mock_tenant.encrypt_public_key = "mock_public_key"
+ mock_query.return_value.where.return_value.first.return_value = mock_tenant
+ mock_encrypt.return_value = b"encrypted_special"
+
+ # Test various special characters
+ special_tokens = [
+ "token\x00with\x00null", # Null bytes
+ "token_with_emoji_😀🎉", # Unicode emoji
+ "token\nwith\nnewlines", # Newlines
+ "token\twith\ttabs", # Tabs
+ "token_with_中文字符", # Chinese characters
+ ]
+
+ for token in special_tokens:
+ result = encrypt_token("tenant-123", token)
+ assert result == base64.b64encode(b"encrypted_special").decode()
+ mock_encrypt.assert_called_with(token, "mock_public_key")
+
+ @patch("models.engine.db.session.query")
+ @patch("libs.rsa.encrypt")
+ def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
+ """Test behavior when token exceeds RSA encryption limits"""
+ mock_tenant = MagicMock()
+ mock_tenant.encrypt_public_key = "mock_public_key"
+ mock_query.return_value.where.return_value.first.return_value = mock_tenant
+
+        # A 2048-bit RSA key can only encrypt a small payload per block
+        # (~245 bytes with PKCS#1 v1.5); the exact limit depends on the
+        # padding scheme. See the sketch at the end of this file.
+ mock_encrypt.side_effect = ValueError("Message too long for RSA key size")
+
+ # Create a token that would exceed RSA limits
+ long_token = "x" * 300
+
+ with pytest.raises(ValueError, match="Message too long for RSA key size"):
+ encrypt_token("tenant-123", long_token)
+
+ @patch("libs.rsa.get_decrypt_decoding")
+ @patch("libs.rsa.decrypt_token_with_decoding")
+ def test_batch_decrypt_loads_key_only_once(self, mock_decrypt_with_decoding, mock_get_decoding):
+ """Verify batch decryption optimization - loads key only once"""
+ mock_rsa_key = MagicMock()
+ mock_cipher_rsa = MagicMock()
+ mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
+
+ # Test with multiple tokens
+ mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3", "token4", "token5"]
+ tokens = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(5)]
+
+ result = batch_decrypt_token("tenant-123", tokens)
+
+ assert result == ["token1", "token2", "token3", "token4", "token5"]
+ # Key should only be loaded once regardless of token count
+ mock_get_decoding.assert_called_once_with("tenant-123")
+ assert mock_decrypt_with_decoding.call_count == 5
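+
+
+# Hedged aside on the size limit referenced in test_should_handle_rsa_size_limits:
+# the usable plaintext per RSA block depends on the padding scheme. The figures
+# below are standard for the named paddings, not values read from Dify's code.
+def max_rsa_plaintext_bytes(key_bits: int = 2048, *, pkcs1_v15: bool = False, hash_len: int = 20) -> int:
+    k = key_bits // 8  # modulus size in bytes (256 for a 2048-bit key)
+    # PKCS#1 v1.5 overhead is 11 bytes (245 usable at 2048 bits); OAEP overhead
+    # is 2 * hash_len + 2 (214 for SHA-1, 190 for SHA-256 at 2048 bits).
+    return k - 11 if pkcs1_v15 else k - 2 * hash_len - 2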
diff --git a/api/tests/unit_tests/core/helper/test_trace_id_helper.py b/api/tests/unit_tests/core/helper/test_trace_id_helper.py
new file mode 100644
index 0000000000..27bfe1af05
--- /dev/null
+++ b/api/tests/unit_tests/core/helper/test_trace_id_helper.py
@@ -0,0 +1,86 @@
+import pytest
+
+from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id
+
+
+class DummyRequest:
+ def __init__(self, headers=None, args=None, json=None, is_json=False):
+ self.headers = headers or {}
+ self.args = args or {}
+ self.json = json
+ self.is_json = is_json
+
+
+class TestTraceIdHelper:
+ """Test cases for trace_id_helper.py"""
+
+ @pytest.mark.parametrize(
+ ("trace_id", "expected"),
+ [
+ ("abc123", True),
+ ("A-B_C-123", True),
+ ("a" * 128, True),
+ ("", False),
+ ("a" * 129, False),
+ ("abc!@#", False),
+ ("空格", False),
+ ("with space", False),
+ ],
+ )
+ def test_is_valid_trace_id(self, trace_id, expected):
+ """Test trace_id validation for various cases"""
+ assert is_valid_trace_id(trace_id) is expected
+
+ def test_get_external_trace_id_from_header(self):
+ """Should extract valid trace_id from header"""
+ req = DummyRequest(headers={"X-Trace-Id": "abc123"})
+ assert get_external_trace_id(req) == "abc123"
+
+ def test_get_external_trace_id_from_args(self):
+ """Should extract valid trace_id from args if header missing"""
+ req = DummyRequest(args={"trace_id": "abc123"})
+ assert get_external_trace_id(req) == "abc123"
+
+ def test_get_external_trace_id_from_json(self):
+ """Should extract valid trace_id from JSON body if header and args missing"""
+ req = DummyRequest(is_json=True, json={"trace_id": "abc123"})
+ assert get_external_trace_id(req) == "abc123"
+
+ def test_get_external_trace_id_priority(self):
+ """Header > args > json priority"""
+ req = DummyRequest(
+ headers={"X-Trace-Id": "header_id"},
+ args={"trace_id": "args_id"},
+ is_json=True,
+ json={"trace_id": "json_id"},
+ )
+ assert get_external_trace_id(req) == "header_id"
+ req2 = DummyRequest(args={"trace_id": "args_id"}, is_json=True, json={"trace_id": "json_id"})
+ assert get_external_trace_id(req2) == "args_id"
+ req3 = DummyRequest(is_json=True, json={"trace_id": "json_id"})
+ assert get_external_trace_id(req3) == "json_id"
+
+ @pytest.mark.parametrize(
+ "req",
+ [
+ DummyRequest(headers={"X-Trace-Id": "!!!"}),
+ DummyRequest(args={"trace_id": "!!!"}),
+ DummyRequest(is_json=True, json={"trace_id": "!!!"}),
+ DummyRequest(),
+ ],
+ )
+ def test_get_external_trace_id_invalid(self, req):
+ """Should return None for invalid or missing trace_id"""
+ assert get_external_trace_id(req) is None
+
+ @pytest.mark.parametrize(
+ ("args", "expected"),
+ [
+ ({"external_trace_id": "abc123"}, {"external_trace_id": "abc123"}),
+ ({"other": "value"}, {}),
+ ({}, {}),
+ ],
+ )
+ def test_extract_external_trace_id_from_args(self, args, expected):
+ """Test extraction of external_trace_id from args mapping"""
+ assert extract_external_trace_id_from_args(args) == expected
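+
+
+# Hedged sketches of the two rules the tests above pin down; both are
+# assumptions for illustration, not the helper's actual source.
+def is_valid_trace_id_sketch(trace_id: str) -> bool:
+    # 1-128 characters drawn from letters, digits, underscore, and hyphen.
+    import re
+
+    return bool(re.fullmatch(r"[A-Za-z0-9_-]{1,128}", trace_id))
+
+
+def get_external_trace_id_sketch(request) -> str | None:
+    # Precedence: X-Trace-Id header, then ?trace_id= query arg, then JSON body.
+    candidates = [
+        request.headers.get("X-Trace-Id"),
+        request.args.get("trace_id"),
+        (request.json or {}).get("trace_id") if request.is_json else None,
+    ]
+    for candidate in candidates:
+        if candidate and is_valid_trace_id(candidate):
+            return candidate
+    return None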
diff --git a/api/tests/unit_tests/core/mcp/client/test_session.py b/api/tests/unit_tests/core/mcp/client/test_session.py
new file mode 100644
index 0000000000..c84169bf15
--- /dev/null
+++ b/api/tests/unit_tests/core/mcp/client/test_session.py
@@ -0,0 +1,471 @@
+import queue
+import threading
+from typing import Any
+
+import pytest
+
+from core.mcp import types
+from core.mcp.entities import RequestContext
+from core.mcp.session.base_session import RequestResponder
+from core.mcp.session.client_session import DEFAULT_CLIENT_INFO, ClientSession
+from core.mcp.types import (
+ LATEST_PROTOCOL_VERSION,
+ ClientNotification,
+ ClientRequest,
+ Implementation,
+ InitializedNotification,
+ InitializeRequest,
+ InitializeResult,
+ JSONRPCMessage,
+ JSONRPCNotification,
+ JSONRPCRequest,
+ JSONRPCResponse,
+ ServerCapabilities,
+ ServerResult,
+ SessionMessage,
+)
+
+
+def test_client_session_initialize():
+ # Create synchronous queues to replace async streams
+ client_to_server: queue.Queue[SessionMessage] = queue.Queue()
+ server_to_client: queue.Queue[SessionMessage] = queue.Queue()
+
+ initialized_notification = None
+
+ def mock_server():
+ nonlocal initialized_notification
+
+ # Receive initialization request
+ session_message = client_to_server.get(timeout=5.0)
+ jsonrpc_request = session_message.message
+ assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+ request = ClientRequest.model_validate(
+ jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ assert isinstance(request.root, InitializeRequest)
+
+ # Create response
+ result = ServerResult(
+ InitializeResult(
+ protocolVersion=LATEST_PROTOCOL_VERSION,
+ capabilities=ServerCapabilities(
+ logging=None,
+ resources=None,
+ tools=None,
+ experimental=None,
+ prompts=None,
+ ),
+ serverInfo=Implementation(name="mock-server", version="0.1.0"),
+ instructions="The server instructions.",
+ )
+ )
+
+ # Send response
+ server_to_client.put(
+ SessionMessage(
+ message=JSONRPCMessage(
+ JSONRPCResponse(
+ jsonrpc="2.0",
+ id=jsonrpc_request.root.id,
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ )
+ )
+ )
+
+ # Receive initialized notification
+ session_notification = client_to_server.get(timeout=5.0)
+ jsonrpc_notification = session_notification.message
+ assert isinstance(jsonrpc_notification.root, JSONRPCNotification)
+ initialized_notification = ClientNotification.model_validate(
+ jsonrpc_notification.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+
+ # Create message handler
+ def message_handler(
+ message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
+ ) -> None:
+ if isinstance(message, Exception):
+ raise message
+
+ # Start mock server thread
+ server_thread = threading.Thread(target=mock_server, daemon=True)
+ server_thread.start()
+
+ # Create and use client session
+ with ClientSession(
+ server_to_client,
+ client_to_server,
+ message_handler=message_handler,
+ ) as session:
+ result = session.initialize()
+
+ # Wait for server thread to complete
+ server_thread.join(timeout=10.0)
+
+ # Assert results
+ assert isinstance(result, InitializeResult)
+ assert result.protocolVersion == LATEST_PROTOCOL_VERSION
+ assert isinstance(result.capabilities, ServerCapabilities)
+ assert result.serverInfo == Implementation(name="mock-server", version="0.1.0")
+ assert result.instructions == "The server instructions."
+
+ # Check that client sent initialized notification
+ assert initialized_notification
+ assert isinstance(initialized_notification.root, InitializedNotification)
+
+
+def test_client_session_custom_client_info():
+ # Create synchronous queues to replace async streams
+ client_to_server: queue.Queue[SessionMessage] = queue.Queue()
+ server_to_client: queue.Queue[SessionMessage] = queue.Queue()
+
+ custom_client_info = Implementation(name="test-client", version="1.2.3")
+ received_client_info = None
+
+ def mock_server():
+ nonlocal received_client_info
+
+ session_message = client_to_server.get(timeout=5.0)
+ jsonrpc_request = session_message.message
+ assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+ request = ClientRequest.model_validate(
+ jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ assert isinstance(request.root, InitializeRequest)
+ received_client_info = request.root.params.clientInfo
+
+ result = ServerResult(
+ InitializeResult(
+ protocolVersion=LATEST_PROTOCOL_VERSION,
+ capabilities=ServerCapabilities(),
+ serverInfo=Implementation(name="mock-server", version="0.1.0"),
+ )
+ )
+
+ server_to_client.put(
+ SessionMessage(
+ message=JSONRPCMessage(
+ JSONRPCResponse(
+ jsonrpc="2.0",
+ id=jsonrpc_request.root.id,
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ )
+ )
+ )
+ # Receive initialized notification
+ client_to_server.get(timeout=5.0)
+
+ # Start mock server thread
+ server_thread = threading.Thread(target=mock_server, daemon=True)
+ server_thread.start()
+
+ with ClientSession(
+ server_to_client,
+ client_to_server,
+ client_info=custom_client_info,
+ ) as session:
+ session.initialize()
+
+ # Wait for server thread to complete
+ server_thread.join(timeout=10.0)
+
+ # Assert that custom client info was sent
+ assert received_client_info == custom_client_info
+
+
+def test_client_session_default_client_info():
+ # Create synchronous queues to replace async streams
+ client_to_server: queue.Queue[SessionMessage] = queue.Queue()
+ server_to_client: queue.Queue[SessionMessage] = queue.Queue()
+
+ received_client_info = None
+
+ def mock_server():
+ nonlocal received_client_info
+
+ session_message = client_to_server.get(timeout=5.0)
+ jsonrpc_request = session_message.message
+ assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+ request = ClientRequest.model_validate(
+ jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ assert isinstance(request.root, InitializeRequest)
+ received_client_info = request.root.params.clientInfo
+
+ result = ServerResult(
+ InitializeResult(
+ protocolVersion=LATEST_PROTOCOL_VERSION,
+ capabilities=ServerCapabilities(),
+ serverInfo=Implementation(name="mock-server", version="0.1.0"),
+ )
+ )
+
+ server_to_client.put(
+ SessionMessage(
+ message=JSONRPCMessage(
+ JSONRPCResponse(
+ jsonrpc="2.0",
+ id=jsonrpc_request.root.id,
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ )
+ )
+ )
+ # Receive initialized notification
+ client_to_server.get(timeout=5.0)
+
+ # Start mock server thread
+ server_thread = threading.Thread(target=mock_server, daemon=True)
+ server_thread.start()
+
+ with ClientSession(
+ server_to_client,
+ client_to_server,
+ ) as session:
+ session.initialize()
+
+ # Wait for server thread to complete
+ server_thread.join(timeout=10.0)
+
+ # Assert that default client info was used
+ assert received_client_info == DEFAULT_CLIENT_INFO
+
+
+def test_client_session_version_negotiation_success():
+ # Create synchronous queues to replace async streams
+ client_to_server: queue.Queue[SessionMessage] = queue.Queue()
+ server_to_client: queue.Queue[SessionMessage] = queue.Queue()
+
+ def mock_server():
+ session_message = client_to_server.get(timeout=5.0)
+ jsonrpc_request = session_message.message
+ assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+ request = ClientRequest.model_validate(
+ jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ assert isinstance(request.root, InitializeRequest)
+
+ # Send supported protocol version
+ result = ServerResult(
+ InitializeResult(
+ protocolVersion=LATEST_PROTOCOL_VERSION,
+ capabilities=ServerCapabilities(),
+ serverInfo=Implementation(name="mock-server", version="0.1.0"),
+ )
+ )
+
+ server_to_client.put(
+ SessionMessage(
+ message=JSONRPCMessage(
+ JSONRPCResponse(
+ jsonrpc="2.0",
+ id=jsonrpc_request.root.id,
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ )
+ )
+ )
+ # Receive initialized notification
+ client_to_server.get(timeout=5.0)
+
+ # Start mock server thread
+ server_thread = threading.Thread(target=mock_server, daemon=True)
+ server_thread.start()
+
+ with ClientSession(
+ server_to_client,
+ client_to_server,
+ ) as session:
+ result = session.initialize()
+
+ # Wait for server thread to complete
+ server_thread.join(timeout=10.0)
+
+ # Should successfully initialize
+ assert isinstance(result, InitializeResult)
+ assert result.protocolVersion == LATEST_PROTOCOL_VERSION
+
+
+def test_client_session_version_negotiation_failure():
+ # Create synchronous queues to replace async streams
+ client_to_server: queue.Queue[SessionMessage] = queue.Queue()
+ server_to_client: queue.Queue[SessionMessage] = queue.Queue()
+
+ def mock_server():
+ session_message = client_to_server.get(timeout=5.0)
+ jsonrpc_request = session_message.message
+ assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+ request = ClientRequest.model_validate(
+ jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ assert isinstance(request.root, InitializeRequest)
+
+ # Send unsupported protocol version
+ result = ServerResult(
+ InitializeResult(
+ protocolVersion="99.99.99", # Unsupported version
+ capabilities=ServerCapabilities(),
+ serverInfo=Implementation(name="mock-server", version="0.1.0"),
+ )
+ )
+
+ server_to_client.put(
+ SessionMessage(
+ message=JSONRPCMessage(
+ JSONRPCResponse(
+ jsonrpc="2.0",
+ id=jsonrpc_request.root.id,
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ )
+ )
+ )
+
+ # Start mock server thread
+ server_thread = threading.Thread(target=mock_server, daemon=True)
+ server_thread.start()
+
+ with ClientSession(
+ server_to_client,
+ client_to_server,
+ ) as session:
+ with pytest.raises(RuntimeError, match="Unsupported protocol version"):
+ session.initialize()
+
+ # Wait for server thread to complete
+ server_thread.join(timeout=10.0)
+
+
+def test_client_capabilities_default():
+ # Create synchronous queues to replace async streams
+ client_to_server: queue.Queue[SessionMessage] = queue.Queue()
+ server_to_client: queue.Queue[SessionMessage] = queue.Queue()
+
+ received_capabilities = None
+
+ def mock_server():
+ nonlocal received_capabilities
+
+ session_message = client_to_server.get(timeout=5.0)
+ jsonrpc_request = session_message.message
+ assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+ request = ClientRequest.model_validate(
+ jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ assert isinstance(request.root, InitializeRequest)
+ received_capabilities = request.root.params.capabilities
+
+ result = ServerResult(
+ InitializeResult(
+ protocolVersion=LATEST_PROTOCOL_VERSION,
+ capabilities=ServerCapabilities(),
+ serverInfo=Implementation(name="mock-server", version="0.1.0"),
+ )
+ )
+
+ server_to_client.put(
+ SessionMessage(
+ message=JSONRPCMessage(
+ JSONRPCResponse(
+ jsonrpc="2.0",
+ id=jsonrpc_request.root.id,
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ )
+ )
+ )
+ # Receive initialized notification
+ client_to_server.get(timeout=5.0)
+
+ # Start mock server thread
+ server_thread = threading.Thread(target=mock_server, daemon=True)
+ server_thread.start()
+
+ with ClientSession(
+ server_to_client,
+ client_to_server,
+ ) as session:
+ session.initialize()
+
+ # Wait for server thread to complete
+ server_thread.join(timeout=10.0)
+
+ # Assert default capabilities
+ assert received_capabilities is not None
+ assert received_capabilities.sampling is not None
+ assert received_capabilities.roots is not None
+ assert received_capabilities.roots.listChanged is True
+
+
+def test_client_capabilities_with_custom_callbacks():
+ # Create synchronous queues to replace async streams
+ client_to_server: queue.Queue[SessionMessage] = queue.Queue()
+ server_to_client: queue.Queue[SessionMessage] = queue.Queue()
+
+ def custom_sampling_callback(
+ context: RequestContext["ClientSession", Any],
+ params: types.CreateMessageRequestParams,
+ ) -> types.CreateMessageResult | types.ErrorData:
+ return types.CreateMessageResult(
+ model="test-model",
+ role="assistant",
+ content=types.TextContent(type="text", text="Custom response"),
+ )
+
+ def custom_list_roots_callback(
+ context: RequestContext["ClientSession", Any],
+ ) -> types.ListRootsResult | types.ErrorData:
+ return types.ListRootsResult(roots=[])
+
+ def mock_server():
+ session_message = client_to_server.get(timeout=5.0)
+ jsonrpc_request = session_message.message
+ assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+ request = ClientRequest.model_validate(
+ jsonrpc_request.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ assert isinstance(request.root, InitializeRequest)
+
+ result = ServerResult(
+ InitializeResult(
+ protocolVersion=LATEST_PROTOCOL_VERSION,
+ capabilities=ServerCapabilities(),
+ serverInfo=Implementation(name="mock-server", version="0.1.0"),
+ )
+ )
+
+ server_to_client.put(
+ SessionMessage(
+ message=JSONRPCMessage(
+ JSONRPCResponse(
+ jsonrpc="2.0",
+ id=jsonrpc_request.root.id,
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ )
+ )
+ )
+ # Receive initialized notification
+ client_to_server.get(timeout=5.0)
+
+ # Start mock server thread
+ server_thread = threading.Thread(target=mock_server, daemon=True)
+ server_thread.start()
+
+ with ClientSession(
+ server_to_client,
+ client_to_server,
+ sampling_callback=custom_sampling_callback,
+ list_roots_callback=custom_list_roots_callback,
+ ) as session:
+ result = session.initialize()
+
+ # Wait for server thread to complete
+ server_thread.join(timeout=10.0)
+
+ # Verify initialization succeeded
+ assert isinstance(result, InitializeResult)
+ assert result.protocolVersion == LATEST_PROTOCOL_VERSION
diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py
new file mode 100644
index 0000000000..8122cd08eb
--- /dev/null
+++ b/api/tests/unit_tests/core/mcp/client/test_sse.py
@@ -0,0 +1,349 @@
+import json
+import queue
+import threading
+import time
+from typing import Any
+from unittest.mock import Mock, patch
+
+import httpx
+import pytest
+
+from core.mcp import types
+from core.mcp.client.sse_client import sse_client
+from core.mcp.error import MCPAuthError, MCPConnectionError
+
+SERVER_NAME = "test_server_for_SSE"
+
+
+def test_sse_message_id_coercion():
+ """Test that string message IDs that look like integers are parsed as integers.
+
+ See
for more details.
+ """
+ json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}'
+ msg = types.JSONRPCMessage.model_validate_json(json_message)
+ expected = types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))
+
+ # Check if both are JSONRPCRequest instances
+ assert isinstance(msg.root, types.JSONRPCRequest)
+ assert isinstance(expected.root, types.JSONRPCRequest)
+
+ assert msg.root.id == expected.root.id
+ assert msg.root.method == expected.root.method
+ assert msg.root.jsonrpc == expected.root.jsonrpc
+
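+
+# Hedged aside on the coercion above: the MCP types are pydantic models, and
+# union_mode="left_to_right" (assumed here to mirror the id field's config)
+# makes the int member of `int | str` win for numeric strings in lax mode,
+# unlike the default "smart" mode, which would keep the exact str type.
+def _union_mode_demo():
+    from pydantic import BaseModel, Field
+
+    class Msg(BaseModel):
+        id: int | str = Field(union_mode="left_to_right")
+
+    assert Msg(id="123").id == 123  # numeric string coerced to int
+    assert Msg(id="abc").id == "abc"  # non-numeric string falls through to str
+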
+
+class MockSSEClient:
+ """Mock SSE client for testing."""
+
+ def __init__(self, url: str, headers: dict[str, Any] | None = None):
+ self.url = url
+ self.headers = headers or {}
+ self.connected = False
+ self.read_queue: queue.Queue = queue.Queue()
+ self.write_queue: queue.Queue = queue.Queue()
+
+ def connect(self):
+ """Simulate connection establishment."""
+ self.connected = True
+
+ # Send endpoint event
+ endpoint_data = "/messages/?session_id=test-session-123"
+ self.read_queue.put(("endpoint", endpoint_data))
+
+ return self.read_queue, self.write_queue
+
+ def send_initialize_response(self):
+ """Send a mock initialize response."""
+ response = {
+ "jsonrpc": "2.0",
+ "id": 1,
+ "result": {
+ "protocolVersion": types.LATEST_PROTOCOL_VERSION,
+ "capabilities": {
+ "logging": None,
+ "resources": None,
+ "tools": None,
+ "experimental": None,
+ "prompts": None,
+ },
+ "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
+ "instructions": "Test server instructions.",
+ },
+ }
+ self.read_queue.put(("message", json.dumps(response)))
+
+
+def test_sse_client_message_id_handling():
+ """Test SSE client properly handles message ID coercion."""
+ mock_client = MockSSEClient("http://test.example/sse")
+ read_queue, write_queue = mock_client.connect()
+
+ # Send a message with string ID that should be coerced to int
+ message_data = {
+ "jsonrpc": "2.0",
+ "id": "456", # String ID
+ "result": {"test": "data"},
+ }
+ read_queue.put(("message", json.dumps(message_data)))
+    # Discard the endpoint event queued by connect()
+    read_queue.get(timeout=1.0)
+    # Get the actual message from the queue
+ event_type, data = read_queue.get(timeout=1.0)
+ assert event_type == "message"
+
+ # Parse the message
+ parsed_message = types.JSONRPCMessage.model_validate_json(data)
+ # Check that it's a JSONRPCResponse and verify the ID
+ assert isinstance(parsed_message.root, types.JSONRPCResponse)
+ assert parsed_message.root.id == 456 # Should be converted to int
+
+
+def test_sse_client_connection_validation():
+ """Test SSE client validates endpoint URLs properly."""
+ test_url = "http://test.example/sse"
+
+ with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
+ # Mock the HTTP client
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ # Mock the SSE connection
+ mock_event_source = Mock()
+ mock_event_source.response.raise_for_status.return_value = None
+ mock_sse_connect.return_value.__enter__.return_value = mock_event_source
+
+ # Mock SSE events
+ class MockSSEEvent:
+ def __init__(self, event_type: str, data: str):
+ self.event = event_type
+ self.data = data
+
+ # Simulate endpoint event
+ endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123")
+ mock_event_source.iter_sse.return_value = [endpoint_event]
+
+ # Test connection
+ try:
+ with sse_client(test_url) as (read_queue, write_queue):
+ assert read_queue is not None
+ assert write_queue is not None
+            except Exception:
+ # Connection might fail due to mocking, but we're testing the validation logic
+ pass
+
+
+def test_sse_client_error_handling():
+ """Test SSE client properly handles various error conditions."""
+ test_url = "http://test.example/sse"
+
+ # Test 401 error handling
+ with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
+ # Mock 401 HTTP error
+ mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401))
+ mock_sse_connect.side_effect = mock_error
+
+ with pytest.raises(MCPAuthError):
+ with sse_client(test_url):
+ pass
+
+ # Test other HTTP errors
+ with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
+ # Mock other HTTP error
+ mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500))
+ mock_sse_connect.side_effect = mock_error
+
+ with pytest.raises(MCPConnectionError):
+ with sse_client(test_url):
+ pass
+
+
+def test_sse_client_timeout_configuration():
+ """Test SSE client timeout configuration."""
+ test_url = "http://test.example/sse"
+ custom_timeout = 10.0
+ custom_sse_timeout = 300.0
+ custom_headers = {"Authorization": "Bearer test-token"}
+
+ with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
+ # Mock successful connection
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ mock_event_source = Mock()
+ mock_event_source.response.raise_for_status.return_value = None
+ mock_event_source.iter_sse.return_value = []
+ mock_sse_connect.return_value.__enter__.return_value = mock_event_source
+
+ try:
+ with sse_client(
+ test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
+ ) as (read_queue, write_queue):
+ # Verify the configuration was passed correctly
+ mock_client_factory.assert_called_with(headers=custom_headers)
+
+ # Check that timeout was configured
+ call_args = mock_sse_connect.call_args
+ assert call_args is not None
+ timeout_arg = call_args[1]["timeout"]
+ assert timeout_arg.read == custom_sse_timeout
+ except Exception:
+ # Connection might fail due to mocking, but we tested the configuration
+ pass
+
+
+def test_sse_transport_endpoint_validation():
+ """Test SSE transport validates endpoint URLs correctly."""
+ from core.mcp.client.sse_client import SSETransport
+
+ transport = SSETransport("http://example.com/sse")
+
+ # Valid endpoint (same origin)
+ valid_endpoint = "http://example.com/messages/session123"
+    assert transport._validate_endpoint_url(valid_endpoint)
+
+ # Invalid endpoint (different origin)
+ invalid_endpoint = "http://malicious.com/messages/session123"
+    assert not transport._validate_endpoint_url(invalid_endpoint)
+
+ # Invalid endpoint (different scheme)
+ invalid_scheme = "https://example.com/messages/session123"
+    assert not transport._validate_endpoint_url(invalid_scheme)
+
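+
+# Hedged sketch of the same-origin rule the assertions above rely on (an
+# assumption about the transport, not its code): the endpoint must share
+# scheme and host:port with the SSE base URL.
+def _same_origin_sketch(base_url: str, endpoint: str) -> bool:
+    from urllib.parse import urlparse
+
+    base, candidate = urlparse(base_url), urlparse(endpoint)
+    return (base.scheme, base.netloc) == (candidate.scheme, candidate.netloc)
+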
+
+def test_sse_transport_message_parsing():
+ """Test SSE transport properly parses different message types."""
+ from core.mcp.client.sse_client import SSETransport
+
+ transport = SSETransport("http://example.com/sse")
+ read_queue: queue.Queue = queue.Queue()
+
+ # Test valid JSON-RPC message
+ valid_message = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}'
+ transport._handle_message_event(valid_message, read_queue)
+
+ # Should have a SessionMessage in the queue
+ message = read_queue.get(timeout=1.0)
+ assert message is not None
+ assert hasattr(message, "message")
+
+ # Test invalid JSON
+ invalid_json = '{"invalid": json}'
+ transport._handle_message_event(invalid_json, read_queue)
+
+ # Should have an exception in the queue
+ error = read_queue.get(timeout=1.0)
+ assert isinstance(error, Exception)
+
+
+def test_sse_client_queue_cleanup():
+ """Test that SSE client properly cleans up queues on exit."""
+ test_url = "http://test.example/sse"
+
+ read_queue = None
+ write_queue = None
+
+ with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
+ # Mock connection that raises an exception
+ mock_sse_connect.side_effect = Exception("Connection failed")
+
+ try:
+ with sse_client(test_url) as (rq, wq):
+ read_queue = rq
+ write_queue = wq
+ except Exception:
+ pass # Expected to fail
+
+ # Queues should be cleaned up even on exception
+ # Note: In real implementation, cleanup should put None to signal shutdown
+
+
+def test_sse_client_url_processing():
+ """Test SSE client URL processing functions."""
+ from core.mcp.client.sse_client import remove_request_params
+
+ # Test URL with parameters
+    url_with_params = "http://example.com/sse?param1=value1&param2=value2"
+ cleaned_url = remove_request_params(url_with_params)
+ assert cleaned_url == "http://example.com/sse"
+
+ # Test URL without parameters
+ url_without_params = "http://example.com/sse"
+ cleaned_url = remove_request_params(url_without_params)
+ assert cleaned_url == "http://example.com/sse"
+
+ # Test URL with path and parameters
+ complex_url = "http://example.com/path/to/sse?session=123&token=abc"
+ cleaned_url = remove_request_params(complex_url)
+ assert cleaned_url == "http://example.com/path/to/sse"
+
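+
+# Hedged sketch of the behavior exercised above (assumed, not the actual
+# helper): drop the query string and fragment, keeping scheme, host, and path.
+def _remove_params_sketch(url: str) -> str:
+    from urllib.parse import urlsplit, urlunsplit
+
+    parts = urlsplit(url)
+    return urlunsplit((parts.scheme, parts.netloc, parts.path, "", ""))
+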
+
+def test_sse_client_headers_propagation():
+ """Test that custom headers are properly propagated in SSE client."""
+ test_url = "http://test.example/sse"
+ custom_headers = {
+ "Authorization": "Bearer test-token",
+ "X-Custom-Header": "test-value",
+ "User-Agent": "test-client/1.0",
+ }
+
+ with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
+ # Mock the client factory to capture headers
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ # Mock the SSE connection
+ mock_event_source = Mock()
+ mock_event_source.response.raise_for_status.return_value = None
+ mock_event_source.iter_sse.return_value = []
+ mock_sse_connect.return_value.__enter__.return_value = mock_event_source
+
+ try:
+ with sse_client(test_url, headers=custom_headers):
+ pass
+ except Exception:
+ pass # Expected due to mocking
+
+ # Verify headers were passed to client factory
+ mock_client_factory.assert_called_with(headers=custom_headers)
+
+
+def test_sse_client_concurrent_access():
+ """Test SSE client behavior with concurrent queue access."""
+ test_read_queue: queue.Queue = queue.Queue()
+
+ # Simulate concurrent producers and consumers
+ def producer():
+ for i in range(10):
+ test_read_queue.put(f"message_{i}")
+ time.sleep(0.01) # Small delay to simulate real conditions
+
+ def consumer():
+ received = []
+ for _ in range(10):
+ try:
+ msg = test_read_queue.get(timeout=2.0)
+ received.append(msg)
+ except queue.Empty:
+ break
+ return received
+
+ # Start producer in separate thread
+ producer_thread = threading.Thread(target=producer, daemon=True)
+ producer_thread.start()
+
+ # Consume messages
+ received_messages = consumer()
+
+ # Wait for producer to finish
+ producer_thread.join(timeout=5.0)
+
+ # Verify all messages were received
+ assert len(received_messages) == 10
+ for i in range(10):
+ assert f"message_{i}" in received_messages
diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
new file mode 100644
index 0000000000..9a30a35a49
--- /dev/null
+++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py
@@ -0,0 +1,450 @@
+"""
+Tests for the StreamableHTTP client transport.
+
+Contains tests for only the client side of the StreamableHTTP transport.
+"""
+
+import queue
+import threading
+import time
+from typing import Any
+from unittest.mock import Mock, patch
+
+from core.mcp import types
+from core.mcp.client.streamable_client import streamablehttp_client
+
+# Test constants
+SERVER_NAME = "test_streamable_http_server"
+TEST_SESSION_ID = "test-session-id-12345"
+INIT_REQUEST = {
+ "jsonrpc": "2.0",
+ "method": "initialize",
+ "params": {
+ "clientInfo": {"name": "test-client", "version": "1.0"},
+ "protocolVersion": "2025-03-26",
+ "capabilities": {},
+ },
+ "id": "init-1",
+}
+
+
+class MockStreamableHTTPClient:
+ """Mock StreamableHTTP client for testing."""
+
+ def __init__(self, url: str, headers: dict[str, Any] | None = None):
+ self.url = url
+ self.headers = headers or {}
+ self.connected = False
+ self.read_queue: queue.Queue = queue.Queue()
+ self.write_queue: queue.Queue = queue.Queue()
+ self.session_id = TEST_SESSION_ID
+
+ def connect(self):
+ """Simulate connection establishment."""
+ self.connected = True
+ return self.read_queue, self.write_queue, lambda: self.session_id
+
+ def send_initialize_response(self):
+ """Send a mock initialize response."""
+ session_message = types.SessionMessage(
+ message=types.JSONRPCMessage(
+ root=types.JSONRPCResponse(
+ jsonrpc="2.0",
+ id="init-1",
+ result={
+ "protocolVersion": types.LATEST_PROTOCOL_VERSION,
+ "capabilities": {
+ "logging": None,
+ "resources": None,
+ "tools": None,
+ "experimental": None,
+ "prompts": None,
+ },
+ "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
+ "instructions": "Test server instructions.",
+ },
+ )
+ )
+ )
+ self.read_queue.put(session_message)
+
+ def send_tools_response(self):
+ """Send a mock tools list response."""
+ session_message = types.SessionMessage(
+ message=types.JSONRPCMessage(
+ root=types.JSONRPCResponse(
+ jsonrpc="2.0",
+ id="tools-1",
+ result={
+ "tools": [
+ {
+ "name": "test_tool",
+ "description": "A test tool",
+ "inputSchema": {"type": "object", "properties": {}},
+ }
+ ],
+ },
+ )
+ )
+ )
+ self.read_queue.put(session_message)
+
+
+def test_streamablehttp_client_message_id_handling():
+ """Test StreamableHTTP client properly handles message ID coercion."""
+ mock_client = MockStreamableHTTPClient("http://test.example/mcp")
+ read_queue, write_queue, get_session_id = mock_client.connect()
+
+ # Send a message with string ID that should be coerced to int
+ response_message = types.SessionMessage(
+ message=types.JSONRPCMessage(root=types.JSONRPCResponse(jsonrpc="2.0", id="789", result={"test": "data"}))
+ )
+ read_queue.put(response_message)
+
+ # Get the message from queue
+ message = read_queue.get(timeout=1.0)
+ assert message is not None
+ assert isinstance(message, types.SessionMessage)
+
+ # Check that the ID was properly handled
+ assert isinstance(message.message.root, types.JSONRPCResponse)
+ assert message.message.root.id == 789 # ID should be coerced to int due to union_mode="left_to_right"
+
+
+def test_streamablehttp_client_connection_validation():
+ """Test StreamableHTTP client validates connections properly."""
+ test_url = "http://test.example/mcp"
+
+ with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ # Mock the HTTP client
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ # Mock successful response
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.headers = {"content-type": "application/json"}
+ mock_response.raise_for_status.return_value = None
+ mock_client.post.return_value = mock_response
+
+ # Test connection
+ try:
+ with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
+ assert read_queue is not None
+ assert write_queue is not None
+ assert get_session_id is not None
+ except Exception:
+ # Connection might fail due to mocking, but we're testing the validation logic
+ pass
+
+
+def test_streamablehttp_client_timeout_configuration():
+ """Test StreamableHTTP client timeout configuration."""
+ test_url = "http://test.example/mcp"
+ custom_headers = {"Authorization": "Bearer test-token"}
+
+ with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ # Mock successful connection
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.headers = {"content-type": "application/json"}
+ mock_response.raise_for_status.return_value = None
+ mock_client.post.return_value = mock_response
+
+ try:
+ with streamablehttp_client(test_url, headers=custom_headers) as (read_queue, write_queue, get_session_id):
+ # Verify the configuration was passed correctly
+ mock_client_factory.assert_called_with(headers=custom_headers)
+ except Exception:
+ # Connection might fail due to mocking, but we tested the configuration
+ pass
+
+
+def test_streamablehttp_client_session_id_handling():
+ """Test StreamableHTTP client properly handles session IDs."""
+ mock_client = MockStreamableHTTPClient("http://test.example/mcp")
+ read_queue, write_queue, get_session_id = mock_client.connect()
+
+ # Test that session ID is available
+ session_id = get_session_id()
+ assert session_id == TEST_SESSION_ID
+
+ # Test that we can use the session ID in subsequent requests
+ assert session_id is not None
+ assert len(session_id) > 0
+
+
+def test_streamablehttp_client_message_parsing():
+ """Test StreamableHTTP client properly parses different message types."""
+ mock_client = MockStreamableHTTPClient("http://test.example/mcp")
+ read_queue, write_queue, get_session_id = mock_client.connect()
+
+ # Test valid initialization response
+ mock_client.send_initialize_response()
+
+ # Should have a SessionMessage in the queue
+ message = read_queue.get(timeout=1.0)
+ assert message is not None
+ assert isinstance(message, types.SessionMessage)
+ assert isinstance(message.message.root, types.JSONRPCResponse)
+
+ # Test tools response
+ mock_client.send_tools_response()
+
+ tools_message = read_queue.get(timeout=1.0)
+ assert tools_message is not None
+ assert isinstance(tools_message, types.SessionMessage)
+
+
+def test_streamablehttp_client_queue_cleanup():
+ """Test that StreamableHTTP client properly cleans up queues on exit."""
+ test_url = "http://test.example/mcp"
+
+ read_queue = None
+ write_queue = None
+
+ with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ # Mock connection that raises an exception
+ mock_client_factory.side_effect = Exception("Connection failed")
+
+ try:
+ with streamablehttp_client(test_url) as (rq, wq, get_session_id):
+ read_queue = rq
+ write_queue = wq
+ except Exception:
+ pass # Expected to fail
+
+ # Queues should be cleaned up even on exception
+ # Note: In real implementation, cleanup should put None to signal shutdown
+
+
+def test_streamablehttp_client_headers_propagation():
+ """Test that custom headers are properly propagated in StreamableHTTP client."""
+ test_url = "http://test.example/mcp"
+ custom_headers = {
+ "Authorization": "Bearer test-token",
+ "X-Custom-Header": "test-value",
+ "User-Agent": "test-client/1.0",
+ }
+
+ with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ # Mock the client factory to capture headers
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.headers = {"content-type": "application/json"}
+ mock_response.raise_for_status.return_value = None
+ mock_client.post.return_value = mock_response
+
+ try:
+ with streamablehttp_client(test_url, headers=custom_headers):
+ pass
+ except Exception:
+ pass # Expected due to mocking
+
+ # Verify headers were passed to client factory
+ # Check that the call was made with headers that include our custom headers
+ mock_client_factory.assert_called_once()
+ call_args = mock_client_factory.call_args
+ assert "headers" in call_args.kwargs
+ passed_headers = call_args.kwargs["headers"]
+
+ # Verify all custom headers are present
+ for key, value in custom_headers.items():
+ assert key in passed_headers
+ assert passed_headers[key] == value
+
+
+def test_streamablehttp_client_concurrent_access():
+ """Test StreamableHTTP client behavior with concurrent queue access."""
+ test_read_queue: queue.Queue = queue.Queue()
+ test_write_queue: queue.Queue = queue.Queue()
+
+ # Simulate concurrent producers and consumers
+ def producer():
+ for i in range(10):
+ test_read_queue.put(f"message_{i}")
+ time.sleep(0.01) # Small delay to simulate real conditions
+
+ def consumer():
+ received = []
+ for _ in range(10):
+ try:
+ msg = test_read_queue.get(timeout=2.0)
+ received.append(msg)
+ except queue.Empty:
+ break
+ return received
+
+ # Start producer in separate thread
+ producer_thread = threading.Thread(target=producer, daemon=True)
+ producer_thread.start()
+
+ # Consume messages
+ received_messages = consumer()
+
+ # Wait for producer to finish
+ producer_thread.join(timeout=5.0)
+
+ # Verify all messages were received
+ assert len(received_messages) == 10
+ for i in range(10):
+ assert f"message_{i}" in received_messages
+
+
+def test_streamablehttp_client_json_vs_sse_mode():
+ """Test StreamableHTTP client handling of JSON vs SSE response modes."""
+ test_url = "http://test.example/mcp"
+
+ with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ # Mock JSON response
+ mock_json_response = Mock()
+ mock_json_response.status_code = 200
+ mock_json_response.headers = {"content-type": "application/json"}
+ mock_json_response.json.return_value = {"result": "json_mode"}
+ mock_json_response.raise_for_status.return_value = None
+
+ # Mock SSE response
+ mock_sse_response = Mock()
+ mock_sse_response.status_code = 200
+ mock_sse_response.headers = {"content-type": "text/event-stream"}
+ mock_sse_response.raise_for_status.return_value = None
+
+ # Test JSON mode
+ mock_client.post.return_value = mock_json_response
+
+ try:
+ with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
+ # Should handle JSON responses
+ assert read_queue is not None
+ assert write_queue is not None
+ except Exception:
+ pass # Expected due to mocking
+
+ # Test SSE mode
+ mock_client.post.return_value = mock_sse_response
+
+ try:
+ with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
+ # Should handle SSE responses
+ assert read_queue is not None
+ assert write_queue is not None
+ except Exception:
+ pass # Expected due to mocking
+
+
+def test_streamablehttp_client_terminate_on_close():
+ """Test StreamableHTTP client terminate_on_close parameter."""
+ test_url = "http://test.example/mcp"
+
+ with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.headers = {"content-type": "application/json"}
+ mock_response.raise_for_status.return_value = None
+ mock_client.post.return_value = mock_response
+ mock_client.delete.return_value = mock_response
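+        # The DELETE mock matters here: with terminate_on_close=True the client
+        # is presumably expected to end the session with an HTTP DELETE when
+        # the context manager exits, so that call must not fail.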
+
+ # Test with terminate_on_close=True (default)
+ try:
+ with streamablehttp_client(test_url, terminate_on_close=True) as (read_queue, write_queue, get_session_id):
+ pass
+ except Exception:
+ pass # Expected due to mocking
+
+ # Test with terminate_on_close=False
+ try:
+ with streamablehttp_client(test_url, terminate_on_close=False) as (read_queue, write_queue, get_session_id):
+ pass
+ except Exception:
+ pass # Expected due to mocking
+
+
+def test_streamablehttp_client_protocol_version_handling():
+ """Test StreamableHTTP client protocol version handling."""
+ mock_client = MockStreamableHTTPClient("http://test.example/mcp")
+ read_queue, write_queue, get_session_id = mock_client.connect()
+
+ # Send initialize response with specific protocol version
+
+ session_message = types.SessionMessage(
+ message=types.JSONRPCMessage(
+ root=types.JSONRPCResponse(
+ jsonrpc="2.0",
+ id="init-1",
+ result={
+ "protocolVersion": "2024-11-05",
+ "capabilities": {},
+ "serverInfo": {"name": SERVER_NAME, "version": "0.1.0"},
+ },
+ )
+ )
+ )
+ read_queue.put(session_message)
+
+ # Get the message and verify protocol version
+ message = read_queue.get(timeout=1.0)
+ assert message is not None
+ assert isinstance(message.message.root, types.JSONRPCResponse)
+ result = message.message.root.result
+ assert result["protocolVersion"] == "2024-11-05"
+
+
+def test_streamablehttp_client_error_response_handling():
+ """Test StreamableHTTP client handling of error responses."""
+ mock_client = MockStreamableHTTPClient("http://test.example/mcp")
+ read_queue, write_queue, get_session_id = mock_client.connect()
+
+ # Send an error response
+ session_message = types.SessionMessage(
+ message=types.JSONRPCMessage(
+ root=types.JSONRPCError(
+ jsonrpc="2.0",
+ id="test-1",
+ error=types.ErrorData(code=-32601, message="Method not found", data=None),
+ )
+ )
+ )
+ read_queue.put(session_message)
+
+ # Get the error message
+ message = read_queue.get(timeout=1.0)
+ assert message is not None
+ assert isinstance(message.message.root, types.JSONRPCError)
+ assert message.message.root.error.code == -32601
+ assert message.message.root.error.message == "Method not found"
+
+
+def test_streamablehttp_client_resumption_token_handling():
+ """Test StreamableHTTP client resumption token functionality."""
+ test_url = "http://test.example/mcp"
+ test_resumption_token = "resume-token-123"
+
+ with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
+ mock_client = Mock()
+ mock_client_factory.return_value.__enter__.return_value = mock_client
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.headers = {"content-type": "application/json", "last-event-id": test_resumption_token}
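+        # The "last-event-id" header mirrors the SSE Last-Event-ID convention;
+        # surfacing it in the mocked response lets the client capture a token
+        # it could replay on reconnect (an assumption about the contract, not
+        # a verified behaviour of the implementation).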
+ mock_response.raise_for_status.return_value = None
+ mock_client.post.return_value = mock_response
+
+ try:
+ with streamablehttp_client(test_url) as (read_queue, write_queue, get_session_id):
+ # Test that resumption token can be captured from headers
+ assert read_queue is not None
+ assert write_queue is not None
+ except Exception:
+ pass # Expected due to mocking
diff --git a/api/tests/unit_tests/core/ops/__init__.py b/api/tests/unit_tests/core/ops/__init__.py
new file mode 100644
index 0000000000..bb92ccdec7
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/__init__.py
@@ -0,0 +1 @@
+# Unit tests for core ops module
diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py
new file mode 100644
index 0000000000..81cb04548d
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/test_config_entity.py
@@ -0,0 +1,385 @@
+import pytest
+from pydantic import ValidationError
+
+from core.ops.entities.config_entity import (
+ AliyunConfig,
+ ArizeConfig,
+ LangfuseConfig,
+ LangSmithConfig,
+ OpikConfig,
+ PhoenixConfig,
+ TracingProviderEnum,
+ WeaveConfig,
+)
+
+
+class TestTracingProviderEnum:
+ """Test cases for TracingProviderEnum"""
+
+ def test_enum_values(self):
+ """Test that all expected enum values are present"""
+ assert TracingProviderEnum.ARIZE == "arize"
+ assert TracingProviderEnum.PHOENIX == "phoenix"
+ assert TracingProviderEnum.LANGFUSE == "langfuse"
+ assert TracingProviderEnum.LANGSMITH == "langsmith"
+ assert TracingProviderEnum.OPIK == "opik"
+ assert TracingProviderEnum.WEAVE == "weave"
+ assert TracingProviderEnum.ALIYUN == "aliyun"
+
+
+class TestArizeConfig:
+ """Test cases for ArizeConfig"""
+
+ def test_valid_config(self):
+ """Test valid Arize configuration"""
+ config = ArizeConfig(
+ api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
+ )
+ assert config.api_key == "test_key"
+ assert config.space_id == "test_space"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.arize.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = ArizeConfig()
+ assert config.api_key is None
+ assert config.space_id is None
+ assert config.project is None
+ assert config.endpoint == "https://otlp.arize.com"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = ArizeConfig(project="")
+ assert config.project == "default"
+
+ def test_project_validation_none(self):
+ """Test project validation with None value"""
+ config = ArizeConfig(project=None)
+ assert config.project == "default"
+
+ def test_endpoint_validation_empty(self):
+ """Test endpoint validation with empty value"""
+ config = ArizeConfig(endpoint="")
+ assert config.endpoint == "https://otlp.arize.com"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation normalizes URL by removing path"""
+ config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
+ assert config.endpoint == "https://custom.arize.com"
+
+ def test_endpoint_validation_invalid_scheme(self):
+ """Test endpoint validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ ArizeConfig(endpoint="ftp://invalid.com")
+
+ def test_endpoint_validation_no_scheme(self):
+ """Test endpoint validation rejects URLs without scheme"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ ArizeConfig(endpoint="invalid.com")
+
+
+class TestPhoenixConfig:
+ """Test cases for PhoenixConfig"""
+
+ def test_valid_config(self):
+ """Test valid Phoenix configuration"""
+ config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.phoenix.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = PhoenixConfig()
+ assert config.api_key is None
+ assert config.project is None
+ assert config.endpoint == "https://app.phoenix.arize.com"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = PhoenixConfig(project="")
+ assert config.project == "default"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation normalizes URL by removing path"""
+ config = PhoenixConfig(endpoint="https://custom.phoenix.com/api/v1")
+ assert config.endpoint == "https://custom.phoenix.com"
+
+
+class TestLangfuseConfig:
+ """Test cases for LangfuseConfig"""
+
+ def test_valid_config(self):
+ """Test valid Langfuse configuration"""
+ config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com")
+ assert config.public_key == "public_key"
+ assert config.secret_key == "secret_key"
+ assert config.host == "https://custom.langfuse.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = LangfuseConfig(public_key="public", secret_key="secret")
+ assert config.host == "https://api.langfuse.com"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ LangfuseConfig()
+
+ with pytest.raises(ValidationError):
+ LangfuseConfig(public_key="public")
+
+ with pytest.raises(ValidationError):
+ LangfuseConfig(secret_key="secret")
+
+ def test_host_validation_empty(self):
+ """Test host validation with empty value"""
+ config = LangfuseConfig(public_key="public", secret_key="secret", host="")
+ assert config.host == "https://api.langfuse.com"
+
+
+class TestLangSmithConfig:
+ """Test cases for LangSmithConfig"""
+
+ def test_valid_config(self):
+ """Test valid LangSmith configuration"""
+ config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com")
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.smith.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = LangSmithConfig(api_key="key", project="project")
+ assert config.endpoint == "https://api.smith.langchain.com"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ LangSmithConfig()
+
+ with pytest.raises(ValidationError):
+ LangSmithConfig(api_key="key")
+
+ with pytest.raises(ValidationError):
+ LangSmithConfig(project="project")
+
+ def test_endpoint_validation_https_only(self):
+ """Test endpoint validation only allows HTTPS"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com")
+
+
+class TestOpikConfig:
+ """Test cases for OpikConfig"""
+
+ def test_valid_config(self):
+ """Test valid Opik configuration"""
+ config = OpikConfig(
+ api_key="test_key",
+ project="test_project",
+ workspace="test_workspace",
+ url="https://custom.comet.com/opik/api/",
+ )
+ assert config.api_key == "test_key"
+ assert config.project == "test_project"
+ assert config.workspace == "test_workspace"
+ assert config.url == "https://custom.comet.com/opik/api/"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = OpikConfig()
+ assert config.api_key is None
+ assert config.project is None
+ assert config.workspace is None
+ assert config.url == "https://www.comet.com/opik/api/"
+
+ def test_project_validation_empty(self):
+ """Test project validation with empty value"""
+ config = OpikConfig(project="")
+ assert config.project == "Default Project"
+
+ def test_url_validation_empty(self):
+ """Test URL validation with empty value"""
+ config = OpikConfig(url="")
+ assert config.url == "https://www.comet.com/opik/api/"
+
+ def test_url_validation_missing_suffix(self):
+ """Test URL validation requires /api/ suffix"""
+ with pytest.raises(ValidationError, match="URL should end with /api/"):
+ OpikConfig(url="https://custom.comet.com/opik/")
+
+ def test_url_validation_invalid_scheme(self):
+ """Test URL validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+ OpikConfig(url="ftp://custom.comet.com/opik/api/")
+
+
+class TestWeaveConfig:
+ """Test cases for WeaveConfig"""
+
+ def test_valid_config(self):
+ """Test valid Weave configuration"""
+ config = WeaveConfig(
+ api_key="test_key",
+ entity="test_entity",
+ project="test_project",
+ endpoint="https://custom.wandb.ai",
+ host="https://custom.host.com",
+ )
+ assert config.api_key == "test_key"
+ assert config.entity == "test_entity"
+ assert config.project == "test_project"
+ assert config.endpoint == "https://custom.wandb.ai"
+ assert config.host == "https://custom.host.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = WeaveConfig(api_key="key", project="project")
+ assert config.entity is None
+ assert config.endpoint == "https://trace.wandb.ai"
+ assert config.host is None
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ WeaveConfig()
+
+ with pytest.raises(ValidationError):
+ WeaveConfig(api_key="key")
+
+ with pytest.raises(ValidationError):
+ WeaveConfig(project="project")
+
+ def test_endpoint_validation_https_only(self):
+ """Test endpoint validation only allows HTTPS"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")
+
+ def test_host_validation_optional(self):
+ """Test host validation is optional but validates when provided"""
+ config = WeaveConfig(api_key="key", project="project", host=None)
+ assert config.host is None
+
+ config = WeaveConfig(api_key="key", project="project", host="")
+ assert config.host == ""
+
+ config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
+ assert config.host == "https://valid.host.com"
+
+ def test_host_validation_invalid_scheme(self):
+ """Test host validation rejects invalid schemes when provided"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")
+
+
+class TestAliyunConfig:
+ """Test cases for AliyunConfig"""
+
+ def test_valid_config(self):
+ """Test valid Aliyun configuration"""
+ config = AliyunConfig(
+ app_name="test_app",
+ license_key="test_license_key",
+ endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com",
+ )
+ assert config.app_name == "test_app"
+ assert config.license_key == "test_license_key"
+ assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_default_values(self):
+ """Test default values are set correctly"""
+ config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+ assert config.app_name == "dify_app"
+
+ def test_missing_required_fields(self):
+ """Test that required fields are enforced"""
+ with pytest.raises(ValidationError):
+ AliyunConfig()
+
+ with pytest.raises(ValidationError):
+ AliyunConfig(license_key="test_license")
+
+ with pytest.raises(ValidationError):
+ AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_app_name_validation_empty(self):
+ """Test app_name validation with empty value"""
+ config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
+ )
+ assert config.app_name == "dify_app"
+
+ def test_endpoint_validation_empty(self):
+ """Test endpoint validation with empty value"""
+ config = AliyunConfig(license_key="test_license", endpoint="")
+ assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_endpoint_validation_with_path(self):
+ """Test endpoint validation normalizes URL by removing path"""
+ config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
+ )
+ assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_endpoint_validation_invalid_scheme(self):
+ """Test endpoint validation rejects invalid schemes"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_endpoint_validation_no_scheme(self):
+ """Test endpoint validation rejects URLs without scheme"""
+ with pytest.raises(ValidationError, match="URL scheme must be one of"):
+ AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+ def test_license_key_required(self):
+ """Test that license_key is required and cannot be empty"""
+ with pytest.raises(ValidationError):
+ AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+
+
+class TestConfigIntegration:
+ """Integration tests for configuration classes"""
+
+ def test_all_configs_can_be_instantiated(self):
+ """Test that all config classes can be instantiated with valid data"""
+ configs = [
+ ArizeConfig(api_key="key"),
+ PhoenixConfig(api_key="key"),
+ LangfuseConfig(public_key="public", secret_key="secret"),
+ LangSmithConfig(api_key="key", project="project"),
+ OpikConfig(api_key="key"),
+ WeaveConfig(api_key="key", project="project"),
+ AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
+ ]
+
+ for config in configs:
+ assert config is not None
+
+ def test_url_normalization_consistency(self):
+ """Test that URL normalization works consistently across configs"""
+ # Test that paths are removed from endpoints
+ arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test")
+ phoenix_config = PhoenixConfig(endpoint="https://phoenix.com/api/v2/")
+ aliyun_config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
+ )
+
+ assert arize_config.endpoint == "https://arize.com"
+ assert phoenix_config.endpoint == "https://phoenix.com"
+ assert aliyun_config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com"
+
+ def test_project_default_values(self):
+ """Test that project default values are set correctly"""
+ arize_config = ArizeConfig(project="")
+ phoenix_config = PhoenixConfig(project="")
+ opik_config = OpikConfig(project="")
+ aliyun_config = AliyunConfig(
+ license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name=""
+ )
+
+ assert arize_config.project == "default"
+ assert phoenix_config.project == "default"
+ assert opik_config.project == "Default Project"
+ assert aliyun_config.app_name == "dify_app"
diff --git a/api/tests/unit_tests/core/ops/test_utils.py b/api/tests/unit_tests/core/ops/test_utils.py
new file mode 100644
index 0000000000..7cc2772acf
--- /dev/null
+++ b/api/tests/unit_tests/core/ops/test_utils.py
@@ -0,0 +1,138 @@
+import pytest
+
+from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
+
+
+class TestValidateUrl:
+ """Test cases for validate_url function"""
+
+ def test_valid_https_url(self):
+ """Test valid HTTPS URL"""
+ result = validate_url("https://example.com", "https://default.com")
+ assert result == "https://example.com"
+
+ def test_valid_http_url(self):
+ """Test valid HTTP URL"""
+ result = validate_url("http://example.com", "https://default.com")
+ assert result == "http://example.com"
+
+ def test_url_with_path_removed(self):
+ """Test that URL path is removed during normalization"""
+ result = validate_url("https://example.com/api/v1/test", "https://default.com")
+ assert result == "https://example.com"
+
+ def test_url_with_query_removed(self):
+ """Test that URL query parameters are removed"""
+ result = validate_url("https://example.com?param=value", "https://default.com")
+ assert result == "https://example.com"
+
+ def test_url_with_fragment_removed(self):
+ """Test that URL fragments are removed"""
+ result = validate_url("https://example.com#section", "https://default.com")
+ assert result == "https://example.com"
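+        # The three normalization tests above are consistent with an
+        # implementation along these lines (a sketch, not the actual code):
+        #     parsed = urlparse(url)
+        #     return f"{parsed.scheme}://{parsed.netloc}"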
+
+ def test_empty_url_returns_default(self):
+ """Test empty URL returns default"""
+ result = validate_url("", "https://default.com")
+ assert result == "https://default.com"
+
+ def test_none_url_returns_default(self):
+ """Test None URL returns default"""
+ result = validate_url(None, "https://default.com")
+ assert result == "https://default.com"
+
+ def test_whitespace_url_returns_default(self):
+ """Test whitespace URL returns default"""
+ result = validate_url(" ", "https://default.com")
+ assert result == "https://default.com"
+
+ def test_invalid_scheme_raises_error(self):
+ """Test invalid scheme raises ValueError"""
+ with pytest.raises(ValueError, match="URL scheme must be one of"):
+ validate_url("ftp://example.com", "https://default.com")
+
+ def test_no_scheme_raises_error(self):
+ """Test URL without scheme raises ValueError"""
+ with pytest.raises(ValueError, match="URL scheme must be one of"):
+ validate_url("example.com", "https://default.com")
+
+ def test_custom_allowed_schemes(self):
+ """Test custom allowed schemes"""
+ result = validate_url("https://example.com", "https://default.com", allowed_schemes=("https",))
+ assert result == "https://example.com"
+
+ with pytest.raises(ValueError, match="URL scheme must be one of"):
+ validate_url("http://example.com", "https://default.com", allowed_schemes=("https",))
+
+
+class TestValidateUrlWithPath:
+ """Test cases for validate_url_with_path function"""
+
+ def test_valid_url_with_path(self):
+ """Test valid URL with path"""
+ result = validate_url_with_path("https://example.com/api/v1", "https://default.com")
+ assert result == "https://example.com/api/v1"
+
+ def test_valid_url_with_required_suffix(self):
+ """Test valid URL with required suffix"""
+ result = validate_url_with_path("https://example.com/api/", "https://default.com", required_suffix="/api/")
+ assert result == "https://example.com/api/"
+
+ def test_url_without_required_suffix_raises_error(self):
+ """Test URL without required suffix raises error"""
+ with pytest.raises(ValueError, match="URL should end with /api/"):
+ validate_url_with_path("https://example.com/api", "https://default.com", required_suffix="/api/")
+
+ def test_empty_url_returns_default(self):
+ """Test empty URL returns default"""
+ result = validate_url_with_path("", "https://default.com")
+ assert result == "https://default.com"
+
+ def test_none_url_returns_default(self):
+ """Test None URL returns default"""
+ result = validate_url_with_path(None, "https://default.com")
+ assert result == "https://default.com"
+
+ def test_invalid_scheme_raises_error(self):
+ """Test invalid scheme raises ValueError"""
+ with pytest.raises(ValueError, match="URL must start with https:// or http://"):
+ validate_url_with_path("ftp://example.com", "https://default.com")
+
+ def test_no_scheme_raises_error(self):
+ """Test URL without scheme raises ValueError"""
+ with pytest.raises(ValueError, match="URL must start with https:// or http://"):
+ validate_url_with_path("example.com", "https://default.com")
+
+
+class TestValidateProjectName:
+ """Test cases for validate_project_name function"""
+
+ def test_valid_project_name(self):
+ """Test valid project name"""
+ result = validate_project_name("my-project", "default")
+ assert result == "my-project"
+
+ def test_empty_project_name_returns_default(self):
+ """Test empty project name returns default"""
+ result = validate_project_name("", "default")
+ assert result == "default"
+
+ def test_none_project_name_returns_default(self):
+ """Test None project name returns default"""
+ result = validate_project_name(None, "default")
+ assert result == "default"
+
+ def test_whitespace_project_name_returns_default(self):
+ """Test whitespace project name returns default"""
+ result = validate_project_name(" ", "default")
+ assert result == "default"
+
+ def test_project_name_with_whitespace_trimmed(self):
+ """Test project name with whitespace is trimmed"""
+ result = validate_project_name(" my-project ", "default")
+ assert result == "my-project"
+
+ def test_custom_default_name(self):
+ """Test custom default name"""
+ result = validate_project_name("", "Custom Default")
+ assert result == "Custom Default"
diff --git a/api/tests/unit_tests/core/repositories/__init__.py b/api/tests/unit_tests/core/repositories/__init__.py
new file mode 100644
index 0000000000..c65d7da61d
--- /dev/null
+++ b/api/tests/unit_tests/core/repositories/__init__.py
@@ -0,0 +1 @@
+# Unit tests for core repositories module
diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py
new file mode 100644
index 0000000000..fce4a6fb6b
--- /dev/null
+++ b/api/tests/unit_tests/core/repositories/test_factory.py
@@ -0,0 +1,455 @@
+"""
+Unit tests for the RepositoryFactory.
+
+This module tests the factory pattern implementation for creating repository instances
+based on configuration, including error handling and validation.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pytest_mock import MockerFixture
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from models import Account, EndUser
+from models.enums import WorkflowRunTriggeredFrom
+from models.workflow import WorkflowNodeExecutionTriggeredFrom
+
+
+class TestRepositoryFactory:
+ """Test cases for RepositoryFactory."""
+
+ def test_import_class_success(self):
+ """Test successful class import."""
+ # Test importing a real class
+ class_path = "unittest.mock.MagicMock"
+ result = DifyCoreRepositoryFactory._import_class(class_path)
+ assert result is MagicMock
+
+ def test_import_class_invalid_path(self):
+ """Test import with invalid module path."""
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._import_class("invalid.module.path")
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_import_class_invalid_class_name(self):
+ """Test import with invalid class name."""
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_import_class_malformed_path(self):
+ """Test import with malformed path (no dots)."""
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._import_class("invalidpath")
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_validate_repository_interface_success(self):
+ """Test successful interface validation."""
+
+ # Create a mock class that implements the required methods
+ class MockRepository:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ # Create a mock interface with the same methods
+ class MockInterface:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ # Should not raise an exception
+ DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
+
+ def test_validate_repository_interface_missing_methods(self):
+ """Test interface validation with missing methods."""
+
+ # Create a mock class that doesn't implement all required methods
+ class IncompleteRepository:
+ def save(self):
+ pass
+
+ # Missing get_by_id method
+
+ # Create a mock interface with required methods
+ class MockInterface:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
+ assert "does not implement required methods" in str(exc_info.value)
+ assert "get_by_id" in str(exc_info.value)
+
+ def test_validate_constructor_signature_success(self):
+ """Test successful constructor signature validation."""
+
+ class MockRepository:
+ def __init__(self, session_factory, user, app_id, triggered_from):
+ pass
+
+ # Should not raise an exception
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
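+        # The name-based validation exercised here is consistent with a check
+        # like (a sketch under that assumption, not the actual implementation):
+        #     params = inspect.signature(repository_class.__init__).parameters
+        #     missing = [p for p in required_params if p not in params]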
+
+ def test_validate_constructor_signature_missing_params(self):
+ """Test constructor validation with missing parameters."""
+
+ class IncompleteRepository:
+ def __init__(self, session_factory, user):
+ # Missing app_id and triggered_from parameters
+ pass
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+ assert "does not accept required parameters" in str(exc_info.value)
+ assert "app_id" in str(exc_info.value)
+ assert "triggered_from" in str(exc_info.value)
+
+ def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
+ """Test constructor validation when inspection fails."""
+ # Mock inspect.signature to raise an exception
+ mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
+
+ class MockRepository:
+ def __init__(self, session_factory):
+ pass
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
+ assert "Failed to validate constructor signature" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
+ """Test successful creation of WorkflowExecutionRepository."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ # Create mock dependencies
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+ app_id = "test-app-id"
+ triggered_from = WorkflowRunTriggeredFrom.APP_RUN
+
+ # Mock the imported class to be a valid repository
+ mock_repository_class = MagicMock()
+ mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
+ mock_repository_class.return_value = mock_repository_instance
+
+ # Mock the validation methods
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+
+ # Verify the repository was created with correct parameters
+ mock_repository_class.assert_called_once_with(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+ assert result is mock_repository_instance
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_import_error(self, mock_config):
+ """Test WorkflowExecutionRepository creation with import error."""
+ # Setup mock configuration with invalid class path
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
+ """Test WorkflowExecutionRepository creation with validation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+
+ # Mock import to succeed but validation to fail
+ mock_repository_class = MagicMock()
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(
+ DifyCoreRepositoryFactory,
+ "_validate_repository_interface",
+ side_effect=RepositoryImportError("Interface validation failed"),
+ ),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert "Interface validation failed" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
+ """Test WorkflowExecutionRepository creation with instantiation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=Account)
+
+ # Mock import and validation to succeed but instantiation to fail
+ mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
+ """Test successful creation of WorkflowNodeExecutionRepository."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ # Create mock dependencies
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+ app_id = "test-app-id"
+ triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
+
+ # Mock the imported class to be a valid repository
+ mock_repository_class = MagicMock()
+ mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
+ mock_repository_class.return_value = mock_repository_instance
+
+ # Mock the validation methods
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+
+ # Verify the repository was created with correct parameters
+ mock_repository_class.assert_called_once_with(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+ assert result is mock_repository_instance
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_import_error(self, mock_config):
+ """Test WorkflowNodeExecutionRepository creation with import error."""
+ # Setup mock configuration with invalid class path
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+ )
+ assert "Cannot import repository class" in str(exc_info.value)
+
+ def test_repository_import_error_exception(self):
+ """Test RepositoryImportError exception."""
+ error_message = "Test error message"
+ exception = RepositoryImportError(error_message)
+ assert str(exception) == error_message
+ assert isinstance(exception, Exception)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
+ """Test repository creation with Engine instead of sessionmaker."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ # Create mock dependencies with Engine instead of sessionmaker
+ mock_engine = MagicMock(spec=Engine)
+ mock_user = MagicMock(spec=Account)
+
+ # Mock the imported class to be a valid repository
+ mock_repository_class = MagicMock()
+ mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
+ mock_repository_class.return_value = mock_repository_instance
+
+ # Mock the validation methods
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
+ session_factory=mock_engine, # Using Engine instead of sessionmaker
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+
+ # Verify the repository was created with the Engine
+ mock_repository_class.assert_called_once_with(
+ session_factory=mock_engine,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+ )
+ assert result is mock_repository_instance
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
+ """Test WorkflowNodeExecutionRepository creation with validation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+
+ # Mock import to succeed but validation to fail
+ mock_repository_class = MagicMock()
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(
+ DifyCoreRepositoryFactory,
+ "_validate_repository_interface",
+ side_effect=RepositoryImportError("Interface validation failed"),
+ ),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+ )
+ assert "Interface validation failed" in str(exc_info.value)
+
+ @patch("core.repositories.factory.dify_config")
+ def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
+ """Test WorkflowNodeExecutionRepository creation with instantiation error."""
+ # Setup mock configuration
+ mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
+
+ mock_session_factory = MagicMock(spec=sessionmaker)
+ mock_user = MagicMock(spec=EndUser)
+
+ # Mock import and validation to succeed but instantiation to fail
+ mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
+ with (
+ patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
+ patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
+ patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
+ ):
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
+ session_factory=mock_session_factory,
+ user=mock_user,
+ app_id="test-app-id",
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+ )
+ assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
+
+ def test_validate_repository_interface_with_private_methods(self):
+ """Test interface validation ignores private methods."""
+
+ # Create a mock class with private methods
+ class MockRepository:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ def _private_method(self):
+ pass
+
+ # Create a mock interface with private methods
+ class MockInterface:
+ def save(self):
+ pass
+
+ def get_by_id(self):
+ pass
+
+ def _private_method(self):
+ pass
+
+ # Should not raise an exception (private methods are ignored)
+ DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
+
+ def test_validate_constructor_signature_with_extra_params(self):
+ """Test constructor validation with extra parameters (should pass)."""
+
+ class MockRepository:
+ def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
+ pass
+
+ # Should not raise an exception (extra parameters are allowed)
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+
+ def test_validate_constructor_signature_with_kwargs(self):
+ """Test constructor validation with **kwargs (current implementation doesn't support this)."""
+
+ class MockRepository:
+ def __init__(self, session_factory, user, **kwargs):
+ pass
+
+ # Current implementation doesn't handle **kwargs, so this should raise an exception
+ with pytest.raises(RepositoryImportError) as exc_info:
+ DifyCoreRepositoryFactory._validate_constructor_signature(
+ MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+ assert "does not accept required parameters" in str(exc_info.value)
+ assert "app_id" in str(exc_info.value)
+ assert "triggered_from" in str(exc_info.value)
diff --git a/api/tests/unit_tests/core/tools/utils/__init__.py b/api/tests/unit_tests/core/tools/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py
new file mode 100644
index 0000000000..8e07293ce0
--- /dev/null
+++ b/api/tests/unit_tests/core/tools/utils/test_parser.py
@@ -0,0 +1,56 @@
+import pytest
+from flask import Flask
+
+from core.tools.utils.parser import ApiBasedToolSchemaParser
+
+
+@pytest.fixture
+def app():
+ app = Flask(__name__)
+ return app
+
+
+def test_parse_openapi_to_tool_bundle_operation_id(app):
+ openapi = {
+ "openapi": "3.0.0",
+ "info": {"title": "Simple API", "version": "1.0.0"},
+ "servers": [{"url": "http://localhost:3000"}],
+ "paths": {
+ "/": {
+ "get": {
+ "summary": "Root endpoint",
+ "responses": {
+ "200": {
+ "description": "Successful response",
+ }
+ },
+ }
+ },
+ "/api/resources": {
+ "get": {
+ "summary": "Non-root endpoint without an operationId",
+ "responses": {
+ "200": {
+ "description": "Successful response",
+ }
+ },
+ },
+ "post": {
+ "summary": "Non-root endpoint with an operationId",
+ "operationId": "createResource",
+ "responses": {
+ "201": {
+ "description": "Resource created",
+ }
+ },
+ },
+ },
+ },
+ }
+ with app.test_request_context():
+ tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
+
+ assert len(tool_bundles) == 3
+ assert tool_bundles[0].operation_id == "_get"
+ assert tool_bundles[1].operation_id == "apiresources_get"
+ assert tool_bundles[2].operation_id == "createResource"
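+    # The fallback ids above suggest the parser derives an operation_id from
+    # path and method when none is declared (roughly: the path with "/"
+    # stripped, plus "_" and the lowercased method), while an explicit
+    # operationId is used verbatim.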
diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py
index 1b035d01a7..4c8d983d20 100644
--- a/api/tests/unit_tests/core/variables/test_segment.py
+++ b/api/tests/unit_tests/core/variables/test_segment.py
@@ -1,14 +1,49 @@
+import dataclasses
+
+from pydantic import BaseModel
+
+from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter
-from core.variables import SecretVariable, StringVariable
+from core.variables.segments import (
+ ArrayAnySegment,
+ ArrayFileSegment,
+ ArrayNumberSegment,
+ ArrayObjectSegment,
+ ArrayStringSegment,
+ FileSegment,
+ FloatSegment,
+ IntegerSegment,
+ NoneSegment,
+ ObjectSegment,
+ Segment,
+ SegmentUnion,
+ StringSegment,
+ get_segment_discriminator,
+)
+from core.variables.types import SegmentType
+from core.variables.variables import (
+ ArrayAnyVariable,
+ ArrayFileVariable,
+ ArrayNumberVariable,
+ ArrayObjectVariable,
+ ArrayStringVariable,
+ FileVariable,
+ FloatVariable,
+ IntegerVariable,
+ NoneVariable,
+ ObjectVariable,
+ SecretVariable,
+ StringVariable,
+ Variable,
+ VariableUnion,
+)
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
def test_segment_group_to_text():
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey("user_id"): "fake-user-id",
- },
+ system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
@@ -30,7 +65,7 @@ def test_segment_group_to_text():
def test_convert_constant_to_segment_group():
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey("user_id"): "fake-user-id",
- },
+ system_variables=SystemVariable(user_id="fake-user-id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
assert segments_group.log == "fake-user-id"
assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id"
+
+
+class _Segments(BaseModel):
+ segments: list[SegmentUnion]
+
+
+class _Variables(BaseModel):
+ variables: list[VariableUnion]
+
+
+def create_test_file(
+ file_type: FileType = FileType.DOCUMENT,
+ transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
+ filename: str = "test.txt",
+ extension: str = ".txt",
+ mime_type: str = "text/plain",
+ size: int = 1024,
+) -> File:
+ """Factory function to create File objects for testing"""
+ return File(
+ tenant_id="test-tenant",
+ type=file_type,
+ transfer_method=transfer_method,
+ filename=filename,
+ extension=extension,
+ mime_type=mime_type,
+ size=size,
+ related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
+ remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
+ storage_key="test-storage-key",
+ )
+
+
+class TestSegmentDumpAndLoad:
+ """Test suite for segment and variable serialization/deserialization"""
+
+ def test_segments(self):
+ """Test basic segment serialization compatibility"""
+ model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
+        json_str = model.model_dump_json()
+        print("Json: ", json_str)
+        loaded = _Segments.model_validate_json(json_str)
+ assert loaded == model
+
+ def test_segment_number(self):
+ """Test number segment serialization compatibility"""
+ model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
+        json_str = model.model_dump_json()
+        print("Json: ", json_str)
+        loaded = _Segments.model_validate_json(json_str)
+ assert loaded == model
+
+ def test_variables(self):
+ """Test variable serialization compatibility"""
+ model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
+        json_str = model.model_dump_json()
+        print("Json: ", json_str)
+        restored = _Variables.model_validate_json(json_str)
+ assert restored == model
+
+ def test_all_segments_serialization(self):
+ """Test serialization/deserialization of all segment types"""
+ # Create one instance of each segment type
+ test_file = create_test_file()
+
+ all_segments: list[SegmentUnion] = [
+ NoneSegment(),
+ StringSegment(value="test string"),
+ IntegerSegment(value=42),
+ FloatSegment(value=3.14),
+ ObjectSegment(value={"key": "value", "number": 123}),
+ FileSegment(value=test_file),
+ ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
+ ArrayStringSegment(value=["hello", "world"]),
+ ArrayNumberSegment(value=[1, 2.5, 3]),
+ ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
+ ArrayFileSegment(value=[]), # Empty array to avoid file complexity
+ ]
+
+ # Test serialization and deserialization
+ model = _Segments(segments=all_segments)
+ json_str = model.model_dump_json()
+ loaded = _Segments.model_validate_json(json_str)
+
+ # Verify all segments are preserved
+ assert len(loaded.segments) == len(all_segments)
+
+ for original, loaded_segment in zip(all_segments, loaded.segments):
+            assert type(loaded_segment) is type(original)
+ assert loaded_segment.value_type == original.value_type
+
+ # For file segments, compare key properties instead of exact equality
+ if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
+ orig_file = original.value
+ loaded_file = loaded_segment.value
+ assert isinstance(orig_file, File)
+ assert isinstance(loaded_file, File)
+ assert loaded_file.tenant_id == orig_file.tenant_id
+ assert loaded_file.type == orig_file.type
+ assert loaded_file.filename == orig_file.filename
+ else:
+ assert loaded_segment.value == original.value
+
+ def test_all_variables_serialization(self):
+ """Test serialization/deserialization of all variable types"""
+ # Create one instance of each variable type
+ test_file = create_test_file()
+
+ all_variables: list[VariableUnion] = [
+ NoneVariable(name="none_var"),
+ StringVariable(value="test string", name="string_var"),
+ IntegerVariable(value=42, name="int_var"),
+ FloatVariable(value=3.14, name="float_var"),
+ ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
+ FileVariable(value=test_file, name="file_var"),
+ ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
+ ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
+ ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
+ ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
+ ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity
+ ]
+
+ # Test serialization and deserialization
+ model = _Variables(variables=all_variables)
+ json_str = model.model_dump_json()
+ loaded = _Variables.model_validate_json(json_str)
+
+ # Verify all variables are preserved
+ assert len(loaded.variables) == len(all_variables)
+
+ for original, loaded_variable in zip(all_variables, loaded.variables):
+            assert type(loaded_variable) is type(original)
+ assert loaded_variable.value_type == original.value_type
+ assert loaded_variable.name == original.name
+
+ # For file variables, compare key properties instead of exact equality
+ if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
+ orig_file = original.value
+ loaded_file = loaded_variable.value
+ assert isinstance(orig_file, File)
+ assert isinstance(loaded_file, File)
+ assert loaded_file.tenant_id == orig_file.tenant_id
+ assert loaded_file.type == orig_file.type
+ assert loaded_file.filename == orig_file.filename
+ else:
+ assert loaded_variable.value == original.value
+
+ def test_segment_discriminator_function_for_segment_types(self):
+ """Test the segment discriminator function"""
+
+ @dataclasses.dataclass
+ class TestCase:
+ segment: Segment
+ expected_segment_type: SegmentType
+
+ file1 = create_test_file()
+ file2 = create_test_file(filename="test2.txt")
+
+ cases = [
+ TestCase(
+ NoneSegment(),
+ SegmentType.NONE,
+ ),
+ TestCase(
+ StringSegment(value=""),
+ SegmentType.STRING,
+ ),
+ TestCase(
+ FloatSegment(value=0.0),
+ SegmentType.FLOAT,
+ ),
+ TestCase(
+ IntegerSegment(value=0),
+ SegmentType.INTEGER,
+ ),
+ TestCase(
+ ObjectSegment(value={}),
+ SegmentType.OBJECT,
+ ),
+ TestCase(
+ FileSegment(value=file1),
+ SegmentType.FILE,
+ ),
+ TestCase(
+ ArrayAnySegment(value=[0, 0.0, ""]),
+ SegmentType.ARRAY_ANY,
+ ),
+ TestCase(
+ ArrayStringSegment(value=[""]),
+ SegmentType.ARRAY_STRING,
+ ),
+ TestCase(
+ ArrayNumberSegment(value=[0, 0.0]),
+ SegmentType.ARRAY_NUMBER,
+ ),
+ TestCase(
+ ArrayObjectSegment(value=[{}]),
+ SegmentType.ARRAY_OBJECT,
+ ),
+ TestCase(
+ ArrayFileSegment(value=[file1, file2]),
+ SegmentType.ARRAY_FILE,
+ ),
+ ]
+
+ for test_case in cases:
+ segment = test_case.segment
+ assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for type {type(segment)}"
+ )
+ model_dict = segment.model_dump(mode="json")
+ assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for serialized form of type {type(segment)}"
+ )
+
+ def test_variable_discriminator_function_for_variable_types(self):
+ """Test the variable discriminator function"""
+
+ @dataclasses.dataclass
+ class TestCase:
+ variable: Variable
+ expected_segment_type: SegmentType
+
+ file1 = create_test_file()
+ file2 = create_test_file(filename="test2.txt")
+
+ cases = [
+ TestCase(
+ NoneVariable(name="none_var"),
+ SegmentType.NONE,
+ ),
+ TestCase(
+ StringVariable(value="test", name="string_var"),
+ SegmentType.STRING,
+ ),
+ TestCase(
+ FloatVariable(value=0.0, name="float_var"),
+ SegmentType.FLOAT,
+ ),
+ TestCase(
+ IntegerVariable(value=0, name="int_var"),
+ SegmentType.INTEGER,
+ ),
+ TestCase(
+ ObjectVariable(value={}, name="object_var"),
+ SegmentType.OBJECT,
+ ),
+ TestCase(
+ FileVariable(value=file1, name="file_var"),
+ SegmentType.FILE,
+ ),
+ TestCase(
+ SecretVariable(value="secret", name="secret_var"),
+ SegmentType.SECRET,
+ ),
+ TestCase(
+ ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
+ SegmentType.ARRAY_ANY,
+ ),
+ TestCase(
+ ArrayStringVariable(value=[""], name="array_string_var"),
+ SegmentType.ARRAY_STRING,
+ ),
+ TestCase(
+ ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
+ SegmentType.ARRAY_NUMBER,
+ ),
+ TestCase(
+ ArrayObjectVariable(value=[{}], name="array_object_var"),
+ SegmentType.ARRAY_OBJECT,
+ ),
+ TestCase(
+ ArrayFileVariable(value=[file1, file2], name="array_file_var"),
+ SegmentType.ARRAY_FILE,
+ ),
+ ]
+
+ for test_case in cases:
+ variable = test_case.variable
+ assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for type {type(variable)}"
+ )
+ model_dict = variable.model_dump(mode="json")
+ assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
+ f"get_segment_discriminator failed for serialized form of type {type(variable)}"
+ )
+
+ def test_invalid_value_for_discriminator(self):
+ # Invalid inputs: an unrecognized value_type, a dict without the key, and non-mapping values
+ assert get_segment_discriminator({"value_type": "invalid"}) is None
+ assert get_segment_discriminator({}) is None
+ assert get_segment_discriminator("not_a_dict") is None
+ assert get_segment_discriminator(42) is None
+ assert get_segment_discriminator(object) is None
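+
+ # The discriminator is deliberately lenient: anything that is not a
+ # Segment/Variable instance or a mapping with a known "value_type"
+ # resolves to None rather than raising.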
diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py
new file mode 100644
index 0000000000..64d0d8c7e7
--- /dev/null
+++ b/api/tests/unit_tests/core/variables/test_segment_type.py
@@ -0,0 +1,60 @@
+from core.variables.types import SegmentType
+
+
+class TestSegmentTypeIsArrayType:
+ """
+ Test class for the SegmentType.is_array_type method.
+
+ Provides comprehensive coverage of all SegmentType values to ensure
+ correct identification of array and non-array types.
+ """
+
+ def test_is_array_type(self):
+ """
+ Test is_array_type for every SegmentType value.
+
+ Asserts the expected array/non-array classification for each value and
+ verifies that the two expectation lists together cover the whole enum.
+ """
+ # Arrange
+ all_segment_types = set(SegmentType)
+ expected_array_types = [
+ SegmentType.ARRAY_ANY,
+ SegmentType.ARRAY_STRING,
+ SegmentType.ARRAY_NUMBER,
+ SegmentType.ARRAY_OBJECT,
+ SegmentType.ARRAY_FILE,
+ ]
+ expected_non_array_types = [
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
+ SegmentType.NUMBER,
+ SegmentType.STRING,
+ SegmentType.OBJECT,
+ SegmentType.SECRET,
+ SegmentType.FILE,
+ SegmentType.NONE,
+ SegmentType.GROUP,
+ ]
+
+ for seg_type in expected_array_types:
+ assert seg_type.is_array_type()
+
+ for seg_type in expected_non_array_types:
+ assert not seg_type.is_array_type()
+
+ # Coverage check: the expectation lists above must cover every SegmentType value
+ covered_types = set(expected_array_types) | set(expected_non_array_types)
+ assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
+
+ def test_all_enum_values_are_supported(self):
+ """
+ Test that is_array_type handles every enum value.
+
+ Validates that every SegmentType enum value can be passed to
+ is_array_type and that the method returns a boolean.
+ """
+ enum_values: list[SegmentType] = list(SegmentType)
+ for seg_type in enum_values:
+ is_array = seg_type.is_array_type()
+ assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py
index 426557c716..925142892c 100644
--- a/api/tests/unit_tests/core/variables/test_variables.py
+++ b/api/tests/unit_tests/core/variables/test_variables.py
@@ -11,6 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
+from core.variables.variables import Variable
def test_frozen_variables():
@@ -75,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
- var = StringVariable(name="text", value="text")
+ var: Variable = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py
new file mode 100644
index 0000000000..cf7cee8710
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py
@@ -0,0 +1,146 @@
+import time
+from decimal import Decimal
+
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
+from core.workflow.system_variable import SystemVariable
+
+
+def create_test_graph_runtime_state() -> GraphRuntimeState:
+ """Factory function to create a GraphRuntimeState with non-empty values for testing."""
+ # Create a variable pool with system variables
+ system_vars = SystemVariable(
+ user_id="test_user_123",
+ app_id="test_app_456",
+ workflow_id="test_workflow_789",
+ workflow_execution_id="test_execution_001",
+ query="test query",
+ conversation_id="test_conv_123",
+ dialogue_count=5,
+ )
+ variable_pool = VariablePool(system_variables=system_vars)
+
+ # Add some variables to the variable pool
+ variable_pool.add(["test_node", "test_var"], "test_value")
+ variable_pool.add(["another_node", "another_var"], 42)
+
+ # Create LLM usage with realistic values
+ llm_usage = LLMUsage(
+ prompt_tokens=150,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.15"),
+ completion_tokens=75,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.15"),
+ total_tokens=225,
+ total_price=Decimal("0.30"),
+ currency="USD",
+ latency=1.25,
+ )
+
+ # Create runtime route state with some node states
+ node_run_state = RuntimeRouteState()
+ node_state = node_run_state.create_node_state("test_node_1")
+ node_run_state.add_route(node_state.id, "target_node_id")
+
+ return GraphRuntimeState(
+ variable_pool=variable_pool,
+ start_at=time.perf_counter(),
+ total_tokens=100,
+ llm_usage=llm_usage,
+ outputs={
+ "string_output": "test result",
+ "int_output": 42,
+ "float_output": 3.14,
+ "list_output": ["item1", "item2", "item3"],
+ "dict_output": {"key1": "value1", "key2": 123},
+ "nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
+ },
+ node_run_steps=5,
+ node_run_state=node_run_state,
+ )
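+
+# Note: start_at holds a time.perf_counter() reading; it is only meaningful
+# relative to other readings in the same process, but it serializes as a
+# plain float, so round-trip equality still holds.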
+
+
+def test_basic_round_trip_serialization():
+ """Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
+ # Create a state with non-empty values
+ original_state = create_test_graph_runtime_state()
+
+ # Serialize to JSON and deserialize back
+ json_data = original_state.model_dump_json()
+ deserialized_state = GraphRuntimeState.model_validate_json(json_data)
+
+ # Core test: ensure the round-trip preserves all values
+ assert deserialized_state == original_state
+
+ # Round-trip through a Python-mode dict
+ dict_data = original_state.model_dump(mode="python")
+ deserialized_state = GraphRuntimeState.model_validate(dict_data)
+ assert deserialized_state == original_state
+
+ # Round-trip through a JSON-mode dict
+ dict_data = original_state.model_dump(mode="json")
+ deserialized_state = GraphRuntimeState.model_validate(dict_data)
+ assert deserialized_state == original_state
+
+
+def test_outputs_field_round_trip():
+ """Test the problematic outputs field maintains values through round-trip serialization."""
+ original_state = create_test_graph_runtime_state()
+
+ # Serialize and deserialize
+ json_data = original_state.model_dump_json()
+ deserialized_state = GraphRuntimeState.model_validate_json(json_data)
+
+ # Verify the outputs field specifically maintains its values
+ assert deserialized_state.outputs == original_state.outputs
+ assert deserialized_state == original_state
+
+
+def test_empty_outputs_round_trip():
+ """Test round-trip serialization with empty outputs field."""
+ variable_pool = VariablePool.empty()
+ original_state = GraphRuntimeState(
+ variable_pool=variable_pool,
+ start_at=time.perf_counter(),
+ outputs={}, # Empty outputs
+ )
+
+ json_data = original_state.model_dump_json()
+ deserialized_state = GraphRuntimeState.model_validate_json(json_data)
+
+ assert deserialized_state == original_state
+
+
+def test_llm_usage_round_trip():
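+ """Round-trip LLMUsage through JSON and both model_dump modes."""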
+ # Create LLM usage with specific decimal values
+ llm_usage = LLMUsage(
+ prompt_tokens=100,
+ prompt_unit_price=Decimal("0.0015"),
+ prompt_price_unit=Decimal(1000),
+ prompt_price=Decimal("0.15"),
+ completion_tokens=50,
+ completion_unit_price=Decimal("0.003"),
+ completion_price_unit=Decimal(1000),
+ completion_price=Decimal("0.15"),
+ total_tokens=150,
+ total_price=Decimal("0.30"),
+ currency="USD",
+ latency=2.5,
+ )
+
+ json_data = llm_usage.model_dump_json()
+ deserialized = LLMUsage.model_validate_json(json_data)
+ assert deserialized == llm_usage
+
+ dict_data = llm_usage.model_dump(mode="python")
+ deserialized = LLMUsage.model_validate(dict_data)
+ assert deserialized == llm_usage
+
+ dict_data = llm_usage.model_dump(mode="json")
+ deserialized = LLMUsage.model_validate(dict_data)
+ assert deserialized == llm_usage
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py
new file mode 100644
index 0000000000..f3de42479a
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py
@@ -0,0 +1,401 @@
+import json
+import uuid
+from datetime import UTC, datetime
+
+import pytest
+from pydantic import ValidationError
+
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
+
+_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
+
+
+class TestRouteNodeStateSerialization:
+ """Test cases for RouteNodeState Pydantic serialization/deserialization."""
+
+ def _make_route_node_state(self):
+ """Build a fully populated RouteNodeState for the serialization tests."""
+
+ node_run_result = NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ inputs={"input_key": "input_value"},
+ outputs={"output_key": "output_value"},
+ )
+
+ node_state = RouteNodeState(
+ node_id="comprehensive_test_node",
+ start_at=_TEST_DATETIME,
+ finished_at=_TEST_DATETIME,
+ status=RouteNodeState.Status.SUCCESS,
+ node_run_result=node_run_result,
+ index=5,
+ paused_at=_TEST_DATETIME,
+ paused_by="user_123",
+ failed_reason="test_reason",
+ )
+ return node_state
+
+ def test_route_node_state_comprehensive_field_validation(self):
+ """Test comprehensive RouteNodeState serialization with all core fields validation."""
+ node_state = self._make_route_node_state()
+ serialized = node_state.model_dump()
+
+ # Comprehensive validation of all RouteNodeState fields
+ assert serialized["node_id"] == "comprehensive_test_node"
+ assert serialized["status"] == RouteNodeState.Status.SUCCESS
+ assert serialized["start_at"] == _TEST_DATETIME
+ assert serialized["finished_at"] == _TEST_DATETIME
+ assert serialized["paused_at"] == _TEST_DATETIME
+ assert serialized["paused_by"] == "user_123"
+ assert serialized["failed_reason"] == "test_reason"
+ assert serialized["index"] == 5
+ assert "id" in serialized
+ assert isinstance(serialized["id"], str)
+ uuid.UUID(serialized["id"]) # Validate UUID format
+
+ # Validate nested NodeRunResult structure
+ assert serialized["node_run_result"] is not None
+ assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
+ assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
+
+ def test_route_node_state_minimal_required_fields(self):
+ """Test RouteNodeState with only required fields, focusing on defaults."""
+ node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
+
+ serialized = node_state.model_dump()
+
+ # Focus on required fields and default values (not re-testing all fields)
+ assert serialized["node_id"] == "minimal_node"
+ assert serialized["start_at"] == _TEST_DATETIME
+ assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
+ assert serialized["index"] == 1 # Default index
+ assert serialized["node_run_result"] is None # Default None
+ json_str = node_state.model_dump_json()
+ deserialized = RouteNodeState.model_validate_json(json_str)
+ assert deserialized == node_state
+
+ def test_route_node_state_deserialization_from_dict(self):
+ """Test RouteNodeState deserialization from dictionary data."""
+ test_datetime = _TEST_DATETIME
+ test_id = str(uuid.uuid4())
+
+ dict_data = {
+ "id": test_id,
+ "node_id": "deserialized_node",
+ "start_at": test_datetime,
+ "status": "success",
+ "finished_at": test_datetime,
+ "index": 3,
+ }
+
+ node_state = RouteNodeState.model_validate(dict_data)
+
+ # Focus on deserialization accuracy
+ assert node_state.id == test_id
+ assert node_state.node_id == "deserialized_node"
+ assert node_state.start_at == test_datetime
+ assert node_state.status == RouteNodeState.Status.SUCCESS
+ assert node_state.finished_at == test_datetime
+ assert node_state.index == 3
+
+ def test_route_node_state_round_trip_consistency(self):
+ node_states = (
+ self._make_route_node_state(),
+ RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
+ )
+ for node_state in node_states:
+ json_str = node_state.model_dump_json()
+ deserialized = RouteNodeState.model_validate_json(json_str)
+ assert deserialized == node_state
+
+ dict_ = node_state.model_dump(mode="python")
+ deserialized = RouteNodeState.model_validate(dict_)
+ assert deserialized == node_state
+
+ dict_ = node_state.model_dump(mode="json")
+ deserialized = RouteNodeState.model_validate(dict_)
+ assert deserialized == node_state
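+
+ # mode="json" dumps datetimes to ISO strings and enums to their string
+ # values; model_validate coerces both back, which is why all three round
+ # trips above compare equal.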
+
+
+class TestRouteNodeStateEnumSerialization:
+ """Dedicated tests for RouteNodeState Status enum serialization behavior."""
+
+ def test_status_enum_model_dump_behavior(self):
+ """Test Status enum serialization in model_dump() returns enum objects."""
+
+ for status_enum in RouteNodeState.Status:
+ node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
+ serialized = node_state.model_dump(mode="python")
+ assert serialized["status"] == status_enum
+ serialized = node_state.model_dump(mode="json")
+ assert serialized["status"] == status_enum.value
+
+ def test_status_enum_json_serialization_behavior(self):
+ """Test Status enum serialization in JSON returns string values."""
+ test_datetime = _TEST_DATETIME
+
+ enum_to_string_mapping = {
+ RouteNodeState.Status.RUNNING: "running",
+ RouteNodeState.Status.SUCCESS: "success",
+ RouteNodeState.Status.FAILED: "failed",
+ RouteNodeState.Status.PAUSED: "paused",
+ RouteNodeState.Status.EXCEPTION: "exception",
+ }
+
+ for status_enum, expected_string in enum_to_string_mapping.items():
+ node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
+
+ json_data = json.loads(node_state.model_dump_json())
+ assert json_data["status"] == expected_string
+
+ def test_status_enum_deserialization_from_string(self):
+ """Test Status enum deserialization from string values."""
+ test_datetime = _TEST_DATETIME
+
+ string_to_enum_mapping = {
+ "running": RouteNodeState.Status.RUNNING,
+ "success": RouteNodeState.Status.SUCCESS,
+ "failed": RouteNodeState.Status.FAILED,
+ "paused": RouteNodeState.Status.PAUSED,
+ "exception": RouteNodeState.Status.EXCEPTION,
+ }
+
+ for status_string, expected_enum in string_to_enum_mapping.items():
+ dict_data = {
+ "node_id": "enum_deserialize_test",
+ "start_at": test_datetime,
+ "status": status_string,
+ }
+
+ node_state = RouteNodeState.model_validate(dict_data)
+ assert node_state.status == expected_enum
+
+
+class TestRuntimeRouteStateSerialization:
+ """Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
+
+ _NODE1_ID = "node_1"
+ _ROUTE_STATE1_ID = str(uuid.uuid4())
+ _NODE2_ID = "node_2"
+ _ROUTE_STATE2_ID = str(uuid.uuid4())
+ _NODE3_ID = "node_3"
+ _ROUTE_STATE3_ID = str(uuid.uuid4())
+
+ def _get_runtime_route_state(self):
+ # Create node states with different configurations
+ node_state_1 = RouteNodeState(
+ id=self._ROUTE_STATE1_ID,
+ node_id=self._NODE1_ID,
+ start_at=_TEST_DATETIME,
+ index=1,
+ )
+ node_state_2 = RouteNodeState(
+ id=self._ROUTE_STATE2_ID,
+ node_id=self._NODE2_ID,
+ start_at=_TEST_DATETIME,
+ status=RouteNodeState.Status.SUCCESS,
+ finished_at=_TEST_DATETIME,
+ index=2,
+ )
+ node_state_3 = RouteNodeState(
+ id=self._ROUTE_STATE3_ID,
+ node_id=self._NODE3_ID,
+ start_at=_TEST_DATETIME,
+ status=RouteNodeState.Status.FAILED,
+ failed_reason="Test failure",
+ index=3,
+ )
+
+ runtime_state = RuntimeRouteState(
+ routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
+ node_state_mapping={
+ node_state_1.id: node_state_1,
+ node_state_2.id: node_state_2,
+ node_state_3.id: node_state_3,
+ },
+ )
+
+ return runtime_state
+
+ def test_runtime_route_state_comprehensive_structure_validation(self):
+ """Test comprehensive RuntimeRouteState serialization with full structure validation."""
+
+ runtime_state = self._get_runtime_route_state()
+ serialized = runtime_state.model_dump()
+
+ # Comprehensive validation of RuntimeRouteState structure
+ assert "routes" in serialized
+ assert "node_state_mapping" in serialized
+ assert isinstance(serialized["routes"], dict)
+ assert isinstance(serialized["node_state_mapping"], dict)
+
+ # Validate routes dictionary structure and content
+ assert len(serialized["routes"]) == 2
+ assert self._ROUTE_STATE1_ID in serialized["routes"]
+ assert self._ROUTE_STATE2_ID in serialized["routes"]
+ assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
+ assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
+
+ # Validate node_state_mapping dictionary structure and content
+ assert len(serialized["node_state_mapping"]) == 3
+ for state_id in [
+ self._ROUTE_STATE1_ID,
+ self._ROUTE_STATE2_ID,
+ self._ROUTE_STATE3_ID,
+ ]:
+ assert state_id in serialized["node_state_mapping"]
+ node_data = serialized["node_state_mapping"][state_id]
+ node_state = runtime_state.node_state_mapping[state_id]
+ assert node_data["node_id"] == node_state.node_id
+ assert node_data["status"] == node_state.status
+ assert node_data["index"] == node_state.index
+
+ def test_runtime_route_state_empty_collections(self):
+ """Test RuntimeRouteState with empty collections, focusing on default behavior."""
+ runtime_state = RuntimeRouteState()
+ serialized = runtime_state.model_dump()
+
+ # Focus on default empty collection behavior
+ assert serialized["routes"] == {}
+ assert serialized["node_state_mapping"] == {}
+ assert isinstance(serialized["routes"], dict)
+ assert isinstance(serialized["node_state_mapping"], dict)
+
+ def test_runtime_route_state_json_serialization_structure(self):
+ """Test RuntimeRouteState JSON serialization structure."""
+ node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
+
+ runtime_state = RuntimeRouteState(
+ routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
+ )
+
+ json_str = runtime_state.model_dump_json()
+ json_data = json.loads(json_str)
+
+ # Focus on JSON structure validation
+ assert isinstance(json_str, str)
+ assert isinstance(json_data, dict)
+ assert "routes" in json_data
+ assert "node_state_mapping" in json_data
+ assert json_data["routes"]["source"] == ["target1", "target2"]
+ assert node_state.id in json_data["node_state_mapping"]
+
+ def test_runtime_route_state_deserialization_from_dict(self):
+ """Test RuntimeRouteState deserialization from dictionary data."""
+ node_id = str(uuid.uuid4())
+
+ dict_data = {
+ "routes": {"source_node": ["target_node_1", "target_node_2"]},
+ "node_state_mapping": {
+ node_id: {
+ "id": node_id,
+ "node_id": "test_node",
+ "start_at": _TEST_DATETIME,
+ "status": "running",
+ "index": 1,
+ }
+ },
+ }
+
+ runtime_state = RuntimeRouteState.model_validate(dict_data)
+
+ # Focus on deserialization accuracy
+ assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
+ assert len(runtime_state.node_state_mapping) == 1
+ assert node_id in runtime_state.node_state_mapping
+
+ deserialized_node = runtime_state.node_state_mapping[node_id]
+ assert deserialized_node.node_id == "test_node"
+ assert deserialized_node.status == RouteNodeState.Status.RUNNING
+ assert deserialized_node.index == 1
+
+ def test_runtime_route_state_round_trip_consistency(self):
+ """Test RuntimeRouteState round-trip serialization consistency."""
+ original = self._get_runtime_route_state()
+
+ # Dictionary round trip
+ dict_data = original.model_dump(mode="python")
+ reconstructed = RuntimeRouteState.model_validate(dict_data)
+ assert reconstructed == original
+
+ dict_data = original.model_dump(mode="json")
+ reconstructed = RuntimeRouteState.model_validate(dict_data)
+ assert reconstructed == original
+
+ # JSON round trip
+ json_str = original.model_dump_json()
+ json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
+ assert json_reconstructed == original
+
+
+class TestSerializationEdgeCases:
+ """Test edge cases and error conditions for serialization/deserialization."""
+
+ def test_invalid_status_deserialization(self):
+ """Test deserialization with invalid status values."""
+ test_datetime = _TEST_DATETIME
+ invalid_data = {
+ "node_id": "invalid_test",
+ "start_at": test_datetime,
+ "status": "invalid_status",
+ }
+
+ with pytest.raises(ValidationError) as exc_info:
+ RouteNodeState.model_validate(invalid_data)
+ assert "status" in str(exc_info.value)
+
+ def test_missing_required_fields_deserialization(self):
+ """Test deserialization with missing required fields."""
+ incomplete_data = {"id": str(uuid.uuid4())}
+
+ with pytest.raises(ValidationError) as exc_info:
+ RouteNodeState.model_validate(incomplete_data)
+ error_str = str(exc_info.value)
+ assert "node_id" in error_str or "start_at" in error_str
+
+ def test_invalid_datetime_deserialization(self):
+ """Test deserialization with invalid datetime values."""
+ invalid_data = {
+ "node_id": "datetime_test",
+ "start_at": "invalid_datetime",
+ "status": "running",
+ }
+
+ with pytest.raises(ValidationError) as exc_info:
+ RouteNodeState.model_validate(invalid_data)
+ assert "start_at" in str(exc_info.value)
+
+ def test_invalid_routes_structure_deserialization(self):
+ """Test RuntimeRouteState deserialization with invalid routes structure."""
+ invalid_data = {
+ "routes": "invalid_routes_structure", # Should be dict
+ "node_state_mapping": {},
+ }
+
+ with pytest.raises(ValidationError) as exc_info:
+ RuntimeRouteState.model_validate(invalid_data)
+ assert "routes" in str(exc_info.value)
+
+ def test_timezone_handling_in_datetime_fields(self):
+ """Test timezone handling in datetime field serialization."""
+ utc_datetime = datetime.now(UTC)
+ naive_datetime = utc_datetime.replace(tzinfo=None)
+
+ node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
+ dict_ = node_state.model_dump()
+
+ assert dict_["start_at"] == naive_datetime
+
+ # Test round trip
+ reconstructed = RouteNodeState.model_validate(dict_)
+ assert reconstructed.start_at == naive_datetime
+ assert reconstructed.start_at.tzinfo is None
+
+ json_str = node_state.model_dump_json()
+
+ reconstructed = RouteNodeState.model_validate_json(json_str)
+ assert reconstructed.start_at == naive_datetime
+ assert reconstructed.start_at.tzinfo is None
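+
+ # Naive datetimes serialize to ISO strings without a UTC offset
+ # (e.g. "2024-01-15T10:30:45"), so tzinfo stays None after the JSON
+ # round trip asserted above.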
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
index 7535ec4866..ed4e42425e 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
@@ -1,3 +1,4 @@
+import time
from unittest.mock import patch
import pytest
@@ -7,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
@@ -19,12 +19,14 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -169,9 +171,11 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
+ system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]),
+ user_inputs={"query": "hi"},
)
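+ # GraphEngine now consumes a GraphRuntimeState that wraps the variable pool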
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -183,7 +187,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@@ -290,15 +294,16 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather in SF",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="what's the weather in SF",
+ conversation_id="abababa",
+ ),
user_inputs={},
)
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -310,7 +315,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@@ -470,15 +475,16 @@ def test_run_branch(mock_close, mock_remove):
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "hi",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="hi",
+ conversation_id="abababa",
+ ),
user_inputs={"uid": "takato"},
)
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -490,7 +496,7 @@ def test_run_branch(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
@@ -799,20 +805,25 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ ),
+ user_inputs={"query": "hi"},
)
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
@@ -824,7 +835,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
index b7f78d91fa..1ef024f46b 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -51,28 +51,33 @@ def test_execute_answer():
# construct variable pool
pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
+ node_config = {
+ "id": "answer",
+ "data": {
+ "title": "123",
+ "type": "answer",
+ "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ },
+ }
+
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- },
+ config=node_config,
)
+ # Initialize node data (now a separate step after construction)
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
index c3a3818655..137e8b889d 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
@@ -3,7 +3,6 @@ from collections.abc import Generator
from datetime import UTC, datetime
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
@@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
+from core.workflow.system_variable import SystemVariable
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@@ -180,12 +180,12 @@ def test_process():
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "what's the weather in SF",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="what's the weather in SF",
+ conversation_id="abababa",
+ ),
user_inputs={},
)
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
index d066fc1e33..bb6d72f51e 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py
@@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import (
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
+from core.workflow.system_variable import SystemVariable
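+# VariablePool now expects a SystemVariable model instead of a dict;
+# SystemVariable.empty() replaces the old bare {} (a pool with no system variables).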
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
@@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable():
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable():
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
@@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable():
def test_extract_selectors_from_template_with_newline():
- variable_pool = VariablePool()
+ variable_pool = VariablePool(system_variables=SystemVariable.empty())
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
@@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline():
def test_executor_with_form_data():
# Prepare the variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(["pre_node_id", "text_field"], "Hello, World!")
@@ -280,7 +281,11 @@ def test_init_headers():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
- return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
+ return Executor(
+ node_data=node_data,
+ timeout=timeout,
+ variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+ )
executor = create_executor("aa\n cc:")
executor._init_headers()
@@ -310,7 +315,11 @@ def test_init_params():
authorization=HttpRequestNodeAuthorization(type="no-auth"),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
- return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool())
+ return Executor(
+ node_data=node_data,
+ timeout=timeout,
+ variable_pool=VariablePool(system_variables=SystemVariable.empty()),
+ )
# Test basic key-value pairs
executor = create_executor("key1:value1\nkey2:value2")
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
index 7fd32a4826..71b3a8f7d8 100644
--- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py
@@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeBody,
HttpRequestNodeData,
)
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch):
),
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@@ -56,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
),
),
)
+
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
+
node = HttpRequestNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -89,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0,
),
)
+
+ # Initialize node data
+ node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -128,7 +135,7 @@ def test_http_request_node_form_with_file(monkeypatch):
),
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
variable_pool.add(
@@ -144,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
),
),
)
+
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
+
node = HttpRequestNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -177,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0,
),
)
+
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -223,7 +237,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
@@ -256,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
+
node = HttpRequestNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -290,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",
diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py
index 362072a3db..f53f391433 100644
--- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py
+++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py
@@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -151,36 +151,41 @@ def test_run():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
+ node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "tt",
+ "title": "迭代",
+ "type": "iteration",
+ },
+ "id": "iteration-1",
+ }
+
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "tt",
- "title": "迭代",
- "type": "iteration",
- },
- "id": "iteration-1",
- },
+ config=node_config,
)
+ # Initialize node data
+ iteration_node.init_node_data(node_config["data"])
+
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -368,36 +373,41 @@ def test_run_parallel():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
+ node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "迭代",
+ "type": "iteration",
+ },
+ "id": "iteration-1",
+ }
+
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "迭代",
- "type": "iteration",
- },
- "id": "iteration-1",
- },
+ config=node_config,
)
+ # Initialize node data
+ iteration_node.init_node_data(node_config["data"])
+
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -584,56 +594,66 @@ def test_iteration_run_in_parallel_mode():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
+ parallel_node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "迭代",
+ "type": "iteration",
+ "is_parallel": True,
+ },
+ "id": "iteration-1",
+ }
+
parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "迭代",
- "type": "iteration",
- "is_parallel": True,
- },
- "id": "iteration-1",
- },
+ config=parallel_node_config,
)
+
+ # Initialize node data
+ parallel_iteration_node.init_node_data(parallel_node_config["data"])
+
+ sequential_node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "迭代",
+ "type": "iteration",
+ "is_parallel": True,
+ },
+ "id": "iteration-1",
+ }
+
sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "迭代",
- "type": "iteration",
- "is_parallel": True,
- },
- "id": "iteration-1",
- },
+ config=sequential_node_config,
)
+ # Initialize node data
+ sequential_iteration_node.init_node_data(sequential_node_config["data"])
+
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
- assert parallel_iteration_node.node_data.parallel_nums == 10
- assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
+ assert parallel_iteration_node._node_data.parallel_nums == 10
+ assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
@@ -808,36 +828,41 @@ def test_iteration_run_error_handle():
# construct variable pool
pool = VariablePool(
- system_variables={
- SystemVariableKey.QUERY: "dify",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "1",
- },
+ system_variables=SystemVariable(
+ user_id="1",
+ files=[],
+ query="dify",
+ conversation_id="abababa",
+ ),
user_inputs={},
environment_variables=[],
)
pool.add(["pe", "list_output"], ["1", "1"])
+ error_node_config = {
+ "data": {
+ "iterator_selector": ["pe", "list_output"],
+ "output_selector": ["tt", "output"],
+ "output_type": "array[string]",
+ "startNodeType": "template-transform",
+ "start_node_id": "iteration-start",
+ "title": "iteration",
+ "type": "iteration",
+ "is_parallel": True,
+ "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
+ },
+ "id": "iteration-1",
+ }
+
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "data": {
- "iterator_selector": ["pe", "list_output"],
- "output_selector": ["tt", "output"],
- "output_type": "array[string]",
- "startNodeType": "template-transform",
- "start_node_id": "iteration-start",
- "title": "iteration",
- "type": "iteration",
- "is_parallel": True,
- "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
- },
- "id": "iteration-1",
- },
+ config=error_node_config,
)
+
+ # Initialize node data
+ iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node
result = iteration_node._run()
result_arr = []
@@ -851,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14
# execute remove abnormal output
- iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
+ iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 336c2befcc..23a7fab7cf 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.provider import ProviderType
from models.workflow import WorkflowType
@@ -104,7 +105,7 @@ def graph() -> Graph:
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
return GraphRuntimeState(
@@ -118,17 +119,20 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
+ node_config = {
+ "id": "1",
+ "data": llm_node_data.model_dump(),
+ }
node = LLMNode(
id="1",
- config={
- "id": "1",
- "data": llm_node_data.model_dump(),
- },
+ config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
return node
@@ -181,7 +185,7 @@ def test_fetch_files_with_file_segment():
related_id="1",
storage_key="",
)
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
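+ # VariablePool.empty() presumably mirrors SystemVariable.empty(): a pool built with no system variables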
variable_pool.add(["sys", "files"], file)
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -209,7 +213,7 @@ def test_fetch_files_with_array_file_segment():
storage_key="",
),
]
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -217,7 +221,7 @@ def test_fetch_files_with_array_file_segment():
def test_fetch_files_with_none_segment():
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], NoneSegment())
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -225,7 +229,7 @@ def test_fetch_files_with_none_segment():
def test_fetch_files_with_array_any_segment():
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
@@ -233,7 +237,7 @@ def test_fetch_files_with_array_any_segment():
def test_fetch_files_with_non_existent_variable():
- variable_pool = VariablePool()
+ variable_pool = VariablePool.empty()
result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])
assert result == []
@@ -487,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
- result = llm_node._handle_list_messages(
+ result = llm_node.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
@@ -505,17 +509,20 @@ def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
+ node_config = {
+ "id": "1",
+ "data": llm_node_data.model_dump(),
+ }
node = LLMNode(
id="1",
- config={
- "id": "1",
- "data": llm_node_data.model_dump(),
- },
+ config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
return node, mock_file_saver
@@ -539,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
- file = llm_node._save_multimodal_image_output(content=content)
+ file = llm_node.save_multimodal_image_output(
+ content=content,
+ file_saver=mock_file_saver,
+ )
+ # Manually append to _file_outputs since the static method doesn't do it
+ llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
@@ -565,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
- file = llm_node._save_multimodal_image_output(content=content)
+ file = llm_node.save_multimodal_image_output(
+ content=content,
+ file_saver=mock_file_saver,
+ )
+ # Manually append to _file_outputs since the static method doesn't do it
+ llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
@@ -581,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents="hello world", file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
@@ -589,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
- [TextPromptMessageContent(data="hello world")]
+ contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
@@ -615,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
- [
+ contents=[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
- ]
+ ],
+ file_saver=mock_file_saver,
+ file_outputs=llm_node._file_outputs,
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
@@ -644,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
- gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
+ gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+ contents=None, file_saver=mock_file_saver, file_outputs=[]
+ )
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
index abc822e98b..466d7bad06 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py
@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -53,7 +53,7 @@ def test_execute_answer():
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
@@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
+ node_config = {
+ "id": "answer",
+ "data": {
+ "title": "123",
+ "type": "answer",
+ "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+ },
+ }
+
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "answer",
- "data": {
- "title": "123",
- "type": "answer",
- "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
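
A pattern that recurs throughout this diff: node constructors now only store the raw `config` dict, and the tests must call `init_node_data(node_config["data"])` afterwards to build the typed node-data model. A self-contained sketch of the assumed two-phase initialization (class and model names here are illustrative, not Dify's):

    from pydantic import BaseModel

    class ExampleNodeData(BaseModel):
        title: str
        type: str

    class ExampleNode:
        def __init__(self, *, id: str, config: dict) -> None:
            self.id = id
            self._config = config          # phase 1: store the raw dict, no parsing
            self._node_data: ExampleNodeData | None = None

        def init_node_data(self, data: dict) -> None:
            # phase 2: validate the raw dict into the typed node-data model
            self._node_data = ExampleNodeData.model_validate(data)

    node = ExampleNode(id="answer", config={"id": "answer", "data": {"title": "123", "type": "answer"}})
    node.init_node_data(node._config["data"])
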
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
index ff60d5974b..3f83428834 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py
@@ -1,9 +1,10 @@
+import time
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
+from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
@@ -11,9 +12,11 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
- variable_pool = {
- "system_variables": {
- SystemVariableKey.QUERY: "clear",
- SystemVariableKey.FILES: [],
- SystemVariableKey.CONVERSATION_ID: "abababa",
- SystemVariableKey.USER_ID: "aaa",
- },
- "user_inputs": user_inputs or {"uid": "takato"},
- }
+ variable_pool = VariablePool(
+ system_variables=SystemVariable(
+ user_id="aaa",
+ files=[],
+ query="clear",
+ conversation_id="abababa",
+ ),
+ user_inputs=user_inputs or {"uid": "takato"},
+ )
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
return GraphEngine(
tenant_id="111",
@@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=500,
max_execution_time=1200,
)
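
The helper above also shows the changed GraphEngine contract: it now receives a `GraphRuntimeState` (which owns the `VariablePool`) rather than a bare `variable_pool` argument. A minimal sketch of that contract, with all other constructor arguments elided:

    # Sketch only; the real GraphEngine takes many more arguments (see the hunk above).
    import time

    class GraphRuntimeStateSketch:
        def __init__(self, *, variable_pool: object, start_at: float) -> None:
            self.variable_pool = variable_pool
            self.start_at = start_at

    class GraphEngineSketch:
        def __init__(self, *, graph: object, graph_runtime_state: GraphRuntimeStateSketch, **_: object) -> None:
            # The engine reaches the variable pool through the runtime state it is handed.
            self._graph = graph
            self._graph_runtime_state = graph_runtime_state

    state = GraphRuntimeStateSketch(variable_pool=object(), start_at=time.perf_counter())
    engine = GraphEngineSketch(graph=object(), graph_runtime_state=state, max_execution_steps=500)
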
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 66c7818adf..486ae51e5f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
- return DocumentExtractorNode(
+ node_config = {"id": "test_node_id", "data": node_data.model_dump()}
+ node = DocumentExtractorNode(
id="test_node_id",
- config={"id": "test_node_id", "data": node_data.model_dump()},
+ config=node_config,
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+ return node
@pytest.fixture
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index c4e411f9d6..8383aee0e4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
+from core.workflow.system_variable import SystemVariable
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db
from models.enums import UserFrom
@@ -37,9 +37,7 @@ def test_execute_if_else_result_true():
)
# construct variable pool
- pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={}
- )
+ pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={})
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
@@ -59,57 +57,62 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")
+ node_config = {
+ "id": "if-else",
+ "data": {
+ "title": "123",
+ "type": "if-else",
+ "logical_operator": "and",
+ "conditions": [
+ {
+ "comparison_operator": "contains",
+ "variable_selector": ["start", "array_contains"],
+ "value": "ab",
+ },
+ {
+ "comparison_operator": "not contains",
+ "variable_selector": ["start", "array_not_contains"],
+ "value": "ab",
+ },
+ {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
+ {
+ "comparison_operator": "not contains",
+ "variable_selector": ["start", "not_contains"],
+ "value": "ab",
+ },
+ {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
+ {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
+ {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
+ {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
+ {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
+ {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
+ {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
+ {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
+ {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
+ {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
+ {
+ "comparison_operator": "≥",
+ "variable_selector": ["start", "greater_than_or_equal"],
+ "value": "22",
+ },
+ {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
+ {"comparison_operator": "null", "variable_selector": ["start", "null"]},
+ {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
+ ],
+ },
+ }
+
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "id": "if-else",
- "data": {
- "title": "123",
- "type": "if-else",
- "logical_operator": "and",
- "conditions": [
- {
- "comparison_operator": "contains",
- "variable_selector": ["start", "array_contains"],
- "value": "ab",
- },
- {
- "comparison_operator": "not contains",
- "variable_selector": ["start", "array_not_contains"],
- "value": "ab",
- },
- {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
- {
- "comparison_operator": "not contains",
- "variable_selector": ["start", "not_contains"],
- "value": "ab",
- },
- {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
- {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
- {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
- {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
- {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
- {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
- {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
- {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
- {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
- {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
- {
- "comparison_operator": "≥",
- "variable_selector": ["start", "greater_than_or_equal"],
- "value": "22",
- },
- {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
- {"comparison_operator": "null", "variable_selector": ["start", "null"]},
- {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
@@ -157,40 +160,45 @@ def test_execute_if_else_result_false():
# construct variable pool
pool = VariablePool(
- system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+ system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
)
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
+ node_config = {
+ "id": "if-else",
+ "data": {
+ "title": "123",
+ "type": "if-else",
+ "logical_operator": "or",
+ "conditions": [
+ {
+ "comparison_operator": "contains",
+ "variable_selector": ["start", "array_contains"],
+ "value": "ab",
+ },
+ {
+ "comparison_operator": "not contains",
+ "variable_selector": ["start", "array_not_contains"],
+ "value": "ab",
+ },
+ ],
+ },
+ }
+
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
- config={
- "id": "if-else",
- "data": {
- "title": "123",
- "type": "if-else",
- "logical_operator": "or",
- "conditions": [
- {
- "comparison_operator": "contains",
- "variable_selector": ["start", "array_contains"],
- "value": "ab",
- },
- {
- "comparison_operator": "not contains",
- "variable_selector": ["start", "array_not_contains"],
- "value": "ab",
- },
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Mock db.session.close()
db.session.close = MagicMock()
@@ -230,17 +238,22 @@ def test_array_file_contains_file_name():
],
)
+ node_config = {
+ "id": "if-else",
+ "data": node_data.model_dump(),
+ }
+
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
- config={
- "id": "if-else",
- "data": node_data.model_dump(),
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index 7d3a1d6a2d..5fc9eab2df 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
+ node_config = {
+ "id": "test_node_id",
+ "data": node_data.model_dump(),
+ }
node = ListOperatorNode(
id="test_node_id",
- config={
- "id": "test_node_id",
- "data": node_data.model_dump(),
- },
+ config=node_config,
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index e121f6338c..0eaabd0c40 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.system_variable import SystemVariable
from models import UserFrom, WorkflowType
@@ -34,15 +35,16 @@ def _create_tool_node():
version="1",
)
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
)
+ node_config = {
+ "id": "1",
+ "data": data.model_dump(),
+ }
node = ToolNode(
id="1",
- config={
- "id": "1",
- "data": data.model_dump(),
- },
+ config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -70,6 +72,8 @@ def _create_tool_node():
start_at=0,
),
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
return node
diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
index deb3e29b86..ee51339427 100644
--- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
+++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py
@@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -68,7 +68,7 @@ def test_overwrite_string_variable():
# construct variable pool
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+ system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -82,23 +82,28 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "assigned_variable_selector": ["conversation", conversation_variable.name],
+ "write_mode": WriteMode.OVER_WRITE.value,
+ "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "assigned_variable_selector": ["conversation", conversation_variable.name],
- "write_mode": WriteMode.OVER_WRITE.value,
- "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
- },
- },
+ config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@@ -165,7 +170,7 @@ def test_append_variable_to_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+ system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -178,23 +183,28 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "assigned_variable_selector": ["conversation", conversation_variable.name],
+ "write_mode": WriteMode.APPEND.value,
+ "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "assigned_variable_selector": ["conversation", conversation_variable.name],
- "write_mode": WriteMode.APPEND.value,
- "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
- },
- },
+ config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
@@ -256,7 +266,7 @@ def test_clear_array():
conversation_id = str(uuid.uuid4())
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id},
+ system_variables=SystemVariable(conversation_id=conversation_id),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -265,23 +275,28 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "assigned_variable_selector": ["conversation", conversation_variable.name],
+ "write_mode": WriteMode.CLEAR.value,
+ "input_variable_selector": [],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "assigned_variable_selector": ["conversation", conversation_variable.name],
- "write_mode": WriteMode.CLEAR.value,
- "input_variable_selector": [],
- },
- },
+ config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
index 7c5597dd89..987eaf7534 100644
--- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
+++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py
@@ -5,12 +5,12 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
+from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -109,34 +109,39 @@ def test_remove_first_from_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_FIRST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_FIRST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@@ -196,34 +201,39 @@ def test_remove_last_from_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_LAST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_LAST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -275,34 +285,39 @@ def test_remove_first_from_empty_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_FIRST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_FIRST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -354,34 +369,39 @@ def test_remove_last_from_empty_array():
)
variable_pool = VariablePool(
- system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
+ system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
+ node_config = {
+ "id": "node_id",
+ "data": {
+ "title": "test",
+ "version": "2",
+ "items": [
+ {
+ "variable_selector": ["conversation", conversation_variable.name],
+ "input_type": InputType.VARIABLE,
+ "operation": Operation.REMOVE_LAST,
+ "value": None,
+ }
+ ],
+ },
+ }
+
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
- config={
- "id": "node_id",
- "data": {
- "title": "test",
- "version": "2",
- "items": [
- {
- "variable_selector": ["conversation", conversation_variable.name],
- "input_type": InputType.VARIABLE,
- "operation": Operation.REMOVE_LAST,
- "value": None,
- }
- ],
- },
- },
+ config=node_config,
)
+ # Initialize node data
+ node.init_node_data(node_config["data"])
+
# Skip the mock assertion since we're in a test environment
list(node.run())
diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py
new file mode 100644
index 0000000000..11d788ed79
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_system_variable.py
@@ -0,0 +1,251 @@
+import json
+from typing import Any
+
+import pytest
+from pydantic import ValidationError
+
+from core.file.enums import FileTransferMethod, FileType
+from core.file.models import File
+from core.workflow.system_variable import SystemVariable
+
+# Test data constants for SystemVariable serialization tests
+VALID_BASE_DATA: dict[str, Any] = {
+ "user_id": "a20f06b1-8703-45ab-937c-860a60072113",
+ "app_id": "661bed75-458d-49c9-b487-fda0762677b9",
+ "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
+}
+
+COMPLETE_VALID_DATA: dict[str, Any] = {
+ **VALID_BASE_DATA,
+ "query": "test query",
+ "files": [],
+ "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
+ "dialogue_count": 5,
+ "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
+}
+
+
+def create_test_file() -> File:
+ """Create a test File object for serialization tests."""
+ return File(
+ tenant_id="test-tenant-id",
+ type=FileType.DOCUMENT,
+ transfer_method=FileTransferMethod.LOCAL_FILE,
+ related_id="test-file-id",
+ filename="test.txt",
+ extension=".txt",
+ mime_type="text/plain",
+ size=1024,
+ storage_key="test-storage-key",
+ )
+
+
+class TestSystemVariableSerialization:
+ """Focused tests for SystemVariable serialization/deserialization logic."""
+
+ def test_basic_deserialization(self):
+ """Test successful deserialization from JSON structure with all fields correctly mapped."""
+ # Test with complete data
+ system_var = SystemVariable(**COMPLETE_VALID_DATA)
+
+ # Verify all fields are correctly mapped
+ assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
+ assert system_var.app_id == COMPLETE_VALID_DATA["app_id"]
+ assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"]
+ assert system_var.query == COMPLETE_VALID_DATA["query"]
+ assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"]
+ assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"]
+ assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
+ assert system_var.files == []
+
+ # Test with minimal data (only required fields)
+ minimal_var = SystemVariable(**VALID_BASE_DATA)
+ assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
+ assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
+ assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
+ assert minimal_var.query is None
+ assert minimal_var.conversation_id is None
+ assert minimal_var.dialogue_count is None
+ assert minimal_var.workflow_execution_id is None
+ assert minimal_var.files == []
+
+ def test_alias_handling(self):
+ """Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
+ workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"
+
+ # Test workflow_run_id only (preferred alias)
+ data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
+ system_var1 = SystemVariable(**data_run_id)
+ assert system_var1.workflow_execution_id == workflow_id
+
+ # Test workflow_execution_id only (direct field name)
+ data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
+ system_var2 = SystemVariable(**data_execution_id)
+ assert system_var2.workflow_execution_id == workflow_id
+
+ # Test both present - workflow_run_id should take precedence
+ data_both = {
+ **VALID_BASE_DATA,
+ "workflow_execution_id": "should-be-ignored",
+ "workflow_run_id": workflow_id,
+ }
+ system_var3 = SystemVariable(**data_both)
+ assert system_var3.workflow_execution_id == workflow_id
+
+ # Test neither present - should be None
+ system_var4 = SystemVariable(**VALID_BASE_DATA)
+ assert system_var4.workflow_execution_id is None
+
+ def test_serialization_round_trip(self):
+ """Test that serialize → deserialize produces the same result with alias handling."""
+ # Create original SystemVariable
+ original = SystemVariable(**COMPLETE_VALID_DATA)
+
+ # Serialize to dict
+ serialized = original.model_dump(mode="json")
+
+ # Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
+ assert "workflow_run_id" in serialized
+ assert "workflow_execution_id" not in serialized
+ assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
+
+ # Deserialize back
+ deserialized = SystemVariable(**serialized)
+
+ # Verify all fields match after round-trip
+ assert deserialized.user_id == original.user_id
+ assert deserialized.app_id == original.app_id
+ assert deserialized.workflow_id == original.workflow_id
+ assert deserialized.query == original.query
+ assert deserialized.conversation_id == original.conversation_id
+ assert deserialized.dialogue_count == original.dialogue_count
+ assert deserialized.workflow_execution_id == original.workflow_execution_id
+ assert list(deserialized.files) == list(original.files)
+
+ def test_json_round_trip(self):
+ """Test JSON serialization/deserialization consistency with proper structure."""
+ # Create original SystemVariable
+ original = SystemVariable(**COMPLETE_VALID_DATA)
+
+ # Serialize to JSON string
+ json_str = original.model_dump_json()
+
+ # Parse JSON and verify structure
+ json_data = json.loads(json_str)
+ assert "workflow_run_id" in json_data
+ assert "workflow_execution_id" not in json_data
+ assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
+
+ # Deserialize from JSON data
+ deserialized = SystemVariable(**json_data)
+
+ # Verify key fields match after JSON round-trip
+ assert deserialized.workflow_execution_id == original.workflow_execution_id
+ assert deserialized.user_id == original.user_id
+ assert deserialized.app_id == original.app_id
+ assert deserialized.workflow_id == original.workflow_id
+
+ def test_files_field_deserialization(self):
+ """Test deserialization with File objects in the files field - SystemVariable specific logic."""
+ # Test with empty files list
+ data_empty = {**VALID_BASE_DATA, "files": []}
+ system_var_empty = SystemVariable(**data_empty)
+ assert system_var_empty.files == []
+
+ # Test with single File object
+ test_file = create_test_file()
+ data_single = {**VALID_BASE_DATA, "files": [test_file]}
+ system_var_single = SystemVariable(**data_single)
+ assert len(system_var_single.files) == 1
+ assert system_var_single.files[0].filename == "test.txt"
+ assert system_var_single.files[0].tenant_id == "test-tenant-id"
+
+ # Test with multiple File objects
+ file1 = File(
+ tenant_id="tenant1",
+ type=FileType.DOCUMENT,
+ transfer_method=FileTransferMethod.LOCAL_FILE,
+ related_id="file1",
+ filename="doc1.txt",
+ storage_key="key1",
+ )
+ file2 = File(
+ tenant_id="tenant2",
+ type=FileType.IMAGE,
+ transfer_method=FileTransferMethod.REMOTE_URL,
+ remote_url="https://example.com/image.jpg",
+ filename="image.jpg",
+ storage_key="key2",
+ )
+
+ data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
+ system_var_multiple = SystemVariable(**data_multiple)
+ assert len(system_var_multiple.files) == 2
+ assert system_var_multiple.files[0].filename == "doc1.txt"
+ assert system_var_multiple.files[1].filename == "image.jpg"
+
+ # Verify files field serialization/deserialization
+ serialized = system_var_multiple.model_dump(mode="json")
+ deserialized = SystemVariable(**serialized)
+ assert len(deserialized.files) == 2
+ assert deserialized.files[0].filename == "doc1.txt"
+ assert deserialized.files[1].filename == "image.jpg"
+
+ def test_alias_serialization_consistency(self):
+ """Test that alias handling works consistently in both serialization directions."""
+ workflow_id = "test-workflow-id"
+
+ # Create with workflow_run_id (alias)
+ data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
+ system_var = SystemVariable(**data_with_alias)
+
+ # Serialize and verify alias is used
+ serialized = system_var.model_dump()
+ assert serialized["workflow_run_id"] == workflow_id
+ assert "workflow_execution_id" not in serialized
+
+ # Deserialize and verify field mapping
+ deserialized = SystemVariable(**serialized)
+ assert deserialized.workflow_execution_id == workflow_id
+
+ # Test JSON serialization path
+ json_serialized = json.loads(system_var.model_dump_json())
+ assert json_serialized["workflow_run_id"] == workflow_id
+ assert "workflow_execution_id" not in json_serialized
+
+ json_deserialized = SystemVariable(**json_serialized)
+ assert json_deserialized.workflow_execution_id == workflow_id
+
+ def test_model_validator_serialization_logic(self):
+ """Test the custom model validator behavior for serialization scenarios."""
+ workflow_id = "test-workflow-execution-id"
+
+ # Test direct instantiation with workflow_execution_id (should work)
+ data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
+ system_var1 = SystemVariable(**data1)
+ assert system_var1.workflow_execution_id == workflow_id
+
+ # Test serialization of the above (should use alias)
+ serialized1 = system_var1.model_dump()
+ assert "workflow_run_id" in serialized1
+ assert serialized1["workflow_run_id"] == workflow_id
+
+ # Test both present - workflow_run_id takes precedence (validator logic)
+ data2 = {
+ **VALID_BASE_DATA,
+ "workflow_execution_id": "should-be-removed",
+ "workflow_run_id": workflow_id,
+ }
+ system_var2 = SystemVariable(**data2)
+ assert system_var2.workflow_execution_id == workflow_id
+
+ # Verify serialization consistency
+ serialized2 = system_var2.model_dump()
+ assert serialized2["workflow_run_id"] == workflow_id
+
+
+def test_constructor_with_extra_key():
+    # SystemVariable is expected to forbid extra keys.
+ with pytest.raises(ValidationError):
+ # This should fail because there is an unexpected key.
+ SystemVariable(invalid_key=1) # type: ignore
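
The new test file above pins down three behaviors: extra keys are rejected, `workflow_run_id` is the serialization alias for `workflow_execution_id`, and when both spellings are supplied the alias wins. A hypothetical reconstruction consistent with those assertions (the real `core.workflow.system_variable.SystemVariable` may differ; in particular, the tests above see the alias in a plain `model_dump()`, so the real model must make alias serialization the default, which this sketch only demonstrates via `by_alias=True`):

    from typing import Any
    from pydantic import BaseModel, ConfigDict, Field, model_validator

    class SystemVariableSketch(BaseModel):
        model_config = ConfigDict(extra="forbid", populate_by_name=True)

        user_id: str | None = None
        app_id: str | None = None
        workflow_id: str | None = None
        query: str | None = None
        conversation_id: str | None = None
        dialogue_count: int | None = None
        files: list[Any] = Field(default_factory=list)
        workflow_execution_id: str | None = Field(default=None, alias="workflow_run_id")

        @model_validator(mode="before")
        @classmethod
        def _prefer_run_id(cls, data: Any) -> Any:
            # When both spellings arrive, "workflow_run_id" wins and the other is dropped.
            if isinstance(data, dict) and "workflow_run_id" in data:
                data.pop("workflow_execution_id", None)
            return data

        @classmethod
        def empty(cls) -> "SystemVariableSketch":
            # Mirrors SystemVariable.empty() as used in the tool-node test above.
            return cls()

    sv = SystemVariableSketch(workflow_execution_id="ignored", workflow_run_id="run-1")
    assert sv.workflow_execution_id == "run-1"
    assert sv.model_dump(by_alias=True)["workflow_run_id"] == "run-1"
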
diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py
index bb8d34fad5..c65b60cb4d 100644
--- a/api/tests/unit_tests/core/workflow/test_variable_pool.py
+++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py
@@ -1,17 +1,43 @@
+import uuid
+from collections import defaultdict
+
import pytest
-from pydantic import ValidationError
from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment
-from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
+from core.variables.segments import (
+ ArrayAnySegment,
+ ArrayFileSegment,
+ ArrayNumberSegment,
+ ArrayObjectSegment,
+ ArrayStringSegment,
+ FloatSegment,
+ IntegerSegment,
+ NoneSegment,
+ ObjectSegment,
+)
+from core.variables.variables import (
+ ArrayNumberVariable,
+ ArrayObjectVariable,
+ ArrayStringVariable,
+ FloatVariable,
+ IntegerVariable,
+ ObjectVariable,
+ StringVariable,
+ VariableUnion,
+)
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
from factories.variable_factory import build_segment, segment_to_variable
@pytest.fixture
def pool():
- return VariablePool(system_variables={}, user_inputs={})
+ return VariablePool(
+ system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"),
+ user_inputs={},
+ )
@pytest.fixture
@@ -52,18 +78,28 @@ def test_use_long_selector(pool):
class TestVariablePool:
def test_constructor(self):
- pool = VariablePool()
+ # Test with minimal required SystemVariable
+ minimal_system_vars = SystemVariable(
+ user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
+ )
+ pool = VariablePool(system_variables=minimal_system_vars)
+
+ # Test with all parameters
pool = VariablePool(
variable_dictionary={},
user_inputs={},
- system_variables={},
+ system_variables=minimal_system_vars,
environment_variables=[],
conversation_variables=[],
)
+        # Test with user inputs plus environment and conversation variables
+ complex_system_vars = SystemVariable(
+ user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
+ )
pool = VariablePool(
user_inputs={"key": "value"},
- system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"},
+ system_variables=complex_system_vars,
environment_variables=[
segment_to_variable(
segment=build_segment(1),
@@ -80,6 +116,323 @@ class TestVariablePool:
],
)
- def test_constructor_with_invalid_system_variable_key(self):
- with pytest.raises(ValidationError):
- VariablePool(system_variables={"invalid_key": "value"}) # type: ignore
+ def test_get_system_variables(self):
+ sys_var = SystemVariable(
+ user_id="test_user_id",
+ app_id="test_app_id",
+ workflow_id="test_workflow_id",
+ workflow_execution_id="test_execution_123",
+ query="test query",
+ conversation_id="test_conv_id",
+ dialogue_count=5,
+ )
+ pool = VariablePool(system_variables=sys_var)
+
+ kv = [
+ ("user_id", sys_var.user_id),
+ ("app_id", sys_var.app_id),
+ ("workflow_id", sys_var.workflow_id),
+ ("workflow_run_id", sys_var.workflow_execution_id),
+ ("query", sys_var.query),
+ ("conversation_id", sys_var.conversation_id),
+ ("dialogue_count", sys_var.dialogue_count),
+ ]
+ for key, expected_value in kv:
+ segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key])
+ assert segment is not None
+ assert segment.value == expected_value
+
+
+class TestVariablePoolSerialization:
+ """Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods.
+
+ These tests focus exclusively on serialization/deserialization logic to ensure that
+ VariablePool data can be properly serialized to dictionaries/JSON and reconstructed
+ while preserving all data integrity.
+ """
+
+ _NODE1_ID = "node_1"
+ _NODE2_ID = "node_2"
+ _NODE3_ID = "node_3"
+
+ def _create_pool_without_file(self):
+ # Create comprehensive system variables
+ system_vars = SystemVariable(
+ user_id="test_user_id",
+ app_id="test_app_id",
+ workflow_id="test_workflow_id",
+ workflow_execution_id="test_execution_123",
+ query="test query",
+ conversation_id="test_conv_id",
+ dialogue_count=5,
+ )
+
+        # Create environment variables covering string, integer, and float types (no file variables in this pool)
+ env_vars: list[VariableUnion] = [
+ StringVariable(
+ id="env_string_id",
+ name="env_string",
+ value="env_string_value",
+ selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"],
+ ),
+ IntegerVariable(
+ id="env_integer_id",
+ name="env_integer",
+ value=1,
+ selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"],
+ ),
+ FloatVariable(
+ id="env_float_id",
+ name="env_float",
+ value=1.0,
+ selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"],
+ ),
+ ]
+
+ # Create conversation variables with complex data
+ conv_vars: list[VariableUnion] = [
+ StringVariable(
+ id="conv_string_id",
+ name="conv_string",
+ value="conv_string_value",
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"],
+ ),
+ IntegerVariable(
+ id="conv_integer_id",
+ name="conv_integer",
+ value=1,
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"],
+ ),
+ FloatVariable(
+ id="conv_float_id",
+ name="conv_float",
+ value=1.0,
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"],
+ ),
+ ObjectVariable(
+ id="conv_object_id",
+ name="conv_object",
+ value={"key": "value", "nested": {"data": 123}},
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"],
+ ),
+ ArrayStringVariable(
+ id="conv_array_string_id",
+ name="conv_array_string",
+ value=["conv_array_string_value"],
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"],
+ ),
+ ArrayNumberVariable(
+ id="conv_array_number_id",
+ name="conv_array_number",
+ value=[1, 1.0],
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"],
+ ),
+ ArrayObjectVariable(
+ id="conv_array_object_id",
+ name="conv_array_object",
+ value=[{"a": 1}, {"b": "2"}],
+ selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"],
+ ),
+ ]
+
+ # Create comprehensive user inputs
+ user_inputs = {
+ "string_input": "test_value",
+ "number_input": 42,
+ "object_input": {"nested": {"key": "value"}},
+ "array_input": ["item1", "item2", "item3"],
+ }
+
+ # Create VariablePool
+ pool = VariablePool(
+ system_variables=system_vars,
+ user_inputs=user_inputs,
+ environment_variables=env_vars,
+ conversation_variables=conv_vars,
+ )
+ return pool
+
+ def _add_node_data_to_pool(self, pool: VariablePool, with_file=False):
+ test_file = File(
+ tenant_id="test_tenant_id",
+ type=FileType.DOCUMENT,
+ transfer_method=FileTransferMethod.LOCAL_FILE,
+ related_id="test_related_id",
+ remote_url="test_url",
+ filename="test_file.txt",
+ storage_key="test_storage_key",
+ )
+
+ # Add various segment types to variable dictionary
+ pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string"))
+ pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123))
+ pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67))
+ pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"}))
+ if with_file:
+ pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file))
+ pool.add((self._NODE1_ID, "none_var"), NoneSegment())
+
+        # Add array segments, including ArrayFileSegment when with_file=True
+ pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"]))
+ pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3]))
+ pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}]))
+ if with_file:
+ pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
+ pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
+
+ # Add nested variables
+ pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
+
+ def test_system_variables(self):
+ sys_vars = SystemVariable(
+ user_id="test_user_id",
+ app_id="test_app_id",
+ workflow_id="test_workflow_id",
+ workflow_execution_id="test_execution_123",
+ query="test query",
+ conversation_id="test_conv_id",
+ dialogue_count=5,
+ )
+ pool = VariablePool(system_variables=sys_vars)
+ json = pool.model_dump_json()
+ pool2 = VariablePool.model_validate_json(json)
+ assert pool2.system_variables == sys_vars
+
+ for mode in ["json", "python"]:
+ dict_ = pool.model_dump(mode=mode)
+ pool2 = VariablePool.model_validate(dict_)
+ assert pool2.system_variables == sys_vars
+
+ def test_pool_without_file_vars(self):
+ pool = self._create_pool_without_file()
+ json = pool.model_dump_json()
+        pool2 = VariablePool.model_validate_json(json)
+ assert pool2.system_variables == pool.system_variables
+ assert pool2.conversation_variables == pool.conversation_variables
+ assert pool2.environment_variables == pool.environment_variables
+ assert pool2.user_inputs == pool.user_inputs
+ assert pool2.variable_dictionary == pool.variable_dictionary
+ assert pool2 == pool
+
+ def test_basic_dictionary_round_trip(self):
+ """Test basic round-trip serialization: model_dump() → model_validate()"""
+ # Create a comprehensive VariablePool with all data types
+ original_pool = self._create_pool_without_file()
+ self._add_node_data_to_pool(original_pool)
+
+ # Serialize to dictionary using Pydantic's model_dump()
+ serialized_data = original_pool.model_dump()
+
+ # Verify serialized data structure
+ assert isinstance(serialized_data, dict)
+ assert "system_variables" in serialized_data
+ assert "user_inputs" in serialized_data
+ assert "environment_variables" in serialized_data
+ assert "conversation_variables" in serialized_data
+ assert "variable_dictionary" in serialized_data
+
+ # Deserialize back using Pydantic's model_validate()
+ reconstructed_pool = VariablePool.model_validate(serialized_data)
+
+ # Verify data integrity is preserved
+ self._assert_pools_equal(original_pool, reconstructed_pool)
+
+ def test_json_round_trip(self):
+ """Test JSON round-trip serialization: model_dump_json() → model_validate_json()"""
+ # Create a comprehensive VariablePool with all data types
+ original_pool = self._create_pool_without_file()
+ self._add_node_data_to_pool(original_pool)
+
+ # Serialize to JSON string using Pydantic's model_dump_json()
+ json_data = original_pool.model_dump_json()
+
+ # Verify JSON is valid string
+ assert isinstance(json_data, str)
+ assert len(json_data) > 0
+
+ # Deserialize back using Pydantic's model_validate_json()
+ reconstructed_pool = VariablePool.model_validate_json(json_data)
+
+ # Verify data integrity is preserved
+ self._assert_pools_equal(original_pool, reconstructed_pool)
+
+ def test_complex_data_serialization(self):
+        """Test serialization of complex data structures, including file segments"""
+ original_pool = self._create_pool_without_file()
+ self._add_node_data_to_pool(original_pool, with_file=True)
+
+ # Test dictionary round-trip
+ dict_data = original_pool.model_dump()
+ reconstructed_dict = VariablePool.model_validate(dict_data)
+
+ # Test JSON round-trip
+ json_data = original_pool.model_dump_json()
+ reconstructed_json = VariablePool.model_validate_json(json_data)
+
+ # Verify both reconstructed pools are equivalent
+ self._assert_pools_equal(reconstructed_dict, reconstructed_json)
+ # TODO: assert the data for file object...
+
+ def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
+ """Assert that two VariablePools contain equivalent data"""
+
+ # Compare system variables
+ assert pool1.system_variables == pool2.system_variables
+
+ # Compare user inputs
+ assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
+
+        # Compare environment variables
+ assert pool1.environment_variables == pool2.environment_variables
+
+        # Compare conversation variables
+ assert pool1.conversation_variables == pool2.conversation_variables
+
+ # Test key variable retrievals to ensure functionality is preserved
+ test_selectors = [
+ (SYSTEM_VARIABLE_NODE_ID, "user_id"),
+ (SYSTEM_VARIABLE_NODE_ID, "app_id"),
+ (ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
+            (ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"),
+ (CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
+ (self._NODE1_ID, "string_var"),
+ (self._NODE1_ID, "int_var"),
+ (self._NODE1_ID, "float_var"),
+ (self._NODE2_ID, "array_string"),
+ (self._NODE2_ID, "array_number"),
+ (self._NODE3_ID, "nested", "deep", "var"),
+ ]
+
+ for selector in test_selectors:
+ val1 = pool1.get(selector)
+ val2 = pool2.get(selector)
+
+ # Both should exist or both should be None
+ assert (val1 is None) == (val2 is None)
+
+ if val1 is not None and val2 is not None:
+ # Values should be equal
+ assert val1.value == val2.value
+ # Value types should be the same (more important than exact class type)
+ assert val1.value_type == val2.value_type
+
+ def test_variable_pool_deserialization_default_dict(self):
+ variable_pool = VariablePool(
+ user_inputs={"a": 1, "b": "2"},
+ system_variables=SystemVariable(workflow_id=str(uuid.uuid4())),
+ environment_variables=[
+ StringVariable(name="str_var", value="a"),
+ ],
+ conversation_variables=[IntegerVariable(name="int_var", value=1)],
+ )
+ assert isinstance(variable_pool.variable_dictionary, defaultdict)
+ json = variable_pool.model_dump_json()
+ loaded = VariablePool.model_validate_json(json)
+ assert isinstance(loaded.variable_dictionary, defaultdict)
+
+ loaded.add(["non_exist_node", "a"], 1)
+
+ pool_dict = variable_pool.model_dump()
+ loaded = VariablePool.model_validate(pool_dict)
+ assert isinstance(loaded.variable_dictionary, defaultdict)
+ loaded.add(["non_exist_node", "a"], 1)
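
`test_variable_pool_deserialization_default_dict` above guards a subtle regression: Pydantic validation would normally coerce the `defaultdict` into a plain `dict`, after which `pool.add()` for an unseen node id could raise `KeyError`. A minimal sketch of one way to preserve the behavior (the real VariablePool may do this differently):

    from collections import defaultdict
    from pydantic import BaseModel, field_validator

    class PoolSketch(BaseModel):
        variable_dictionary: dict[str, dict[str, int]] = {}

        @field_validator("variable_dictionary", mode="after")
        @classmethod
        def _as_defaultdict(cls, v: dict) -> dict:
            # Re-wrap the validated plain dict so lookups of unseen node ids
            # create an empty mapping instead of raising KeyError.
            return defaultdict(dict, v)

    loaded = PoolSketch.model_validate({"variable_dictionary": {"node": {"a": 1}}})
    assert isinstance(loaded.variable_dictionary, defaultdict)
    loaded.variable_dictionary["non_exist_node"]["a"] = 1  # safe: defaultdict creates the inner dict
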
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
index 646de8bf3a..4866db1fdb 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py
@@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
-from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from models.enums import CreatorUserRole
from models.model import AppMode
@@ -67,28 +67,25 @@ def real_app_generate_entity():
@pytest.fixture
def real_workflow_system_variables():
- return {
- SystemVariableKey.QUERY: "test query",
- SystemVariableKey.CONVERSATION_ID: "test-conversation-id",
- SystemVariableKey.USER_ID: "test-user-id",
- SystemVariableKey.APP_ID: "test-app-id",
- SystemVariableKey.WORKFLOW_ID: "test-workflow-id",
- SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id",
- }
+ return SystemVariable(
+ query="test query",
+ conversation_id="test-conversation-id",
+ user_id="test-user-id",
+ app_id="test-app-id",
+ workflow_id="test-workflow-id",
+ workflow_execution_id="test-workflow-run-id",
+ )
@pytest.fixture
def mock_node_execution_repository():
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
- repo.get_by_node_execution_id.return_value = None
- repo.get_running_executions.return_value = []
return repo
@pytest.fixture
def mock_workflow_execution_repository():
repo = MagicMock(spec=WorkflowExecutionRepository)
- repo.get.return_value = None
return repo
@@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_success(
@@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
- # Mock get_running_executions to return an empty list
- workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
+ # No running node executions in cache (empty cache)
# Call the method
result = workflow_cycle_manager.handle_workflow_run_failed(
@@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Create a mock event
event = MagicMock(spec=QueueNodeStartedEvent)
@@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock the repository get method to return the real execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
# Call the method
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
@@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
# Verify the result
assert result == workflow_execution
- # Test error case
- workflow_cycle_manager._workflow_execution_repository.get.return_value = None
+ # Test error case - clear cache
+ workflow_cycle_manager._workflow_execution_cache.clear()
# Expect an error when execution is not found
- with pytest.raises(ValueError):
+ from core.app.task_pipeline.exc import WorkflowRunNotFoundError
+
+ with pytest.raises(WorkflowRunNotFoundError):
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
@@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock the repository to return the node execution
- workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+ # Pre-populate the cache with the node execution
+ workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_success(
@@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
started_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
- workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
+ # Pre-populate the cache with the workflow execution
+ workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_partial_success(
@@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
- # Mock the repository to return the node execution
- workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
+ # Pre-populate the cache with the node execution
+ workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_failed(
diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py
index f1cb937bb3..54bf6558bf 100644
--- a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py
+++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py
@@ -10,7 +10,7 @@ class TestAppendVariablesRecursively:
def test_append_simple_dict_value(self):
"""Test appending a simple dictionary value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["output"]
variable_value = {"name": "John", "age": 30}
@@ -33,7 +33,7 @@ class TestAppendVariablesRecursively:
def test_append_object_segment_value(self):
"""Test appending an ObjectSegment value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["result"]
@@ -60,7 +60,7 @@ class TestAppendVariablesRecursively:
def test_append_nested_dict_value(self):
"""Test appending a nested dictionary value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["data"]
@@ -97,7 +97,7 @@ class TestAppendVariablesRecursively:
def test_append_non_dict_value(self):
"""Test appending a non-dictionary value (should not recurse)"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["simple"]
variable_value = "simple_string"
@@ -114,7 +114,7 @@ class TestAppendVariablesRecursively:
def test_append_segment_non_object_value(self):
"""Test appending a Segment that is not ObjectSegment (should not recurse)"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["text"]
variable_value = StringSegment(value="Hello World")
@@ -132,7 +132,7 @@ class TestAppendVariablesRecursively:
def test_append_empty_dict_value(self):
"""Test appending an empty dictionary value"""
- pool = VariablePool()
+ pool = VariablePool.empty()
node_id = "test_node"
variable_key_list = ["empty"]
variable_value: dict[str, Any] = {}
diff --git a/api/tests/unit_tests/extensions/test_redis.py b/api/tests/unit_tests/extensions/test_redis.py
new file mode 100644
index 0000000000..933fa32894
--- /dev/null
+++ b/api/tests/unit_tests/extensions/test_redis.py
@@ -0,0 +1,53 @@
+from redis import RedisError
+
+from extensions.ext_redis import redis_fallback
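+
+# redis_fallback(default_return=...) should swallow RedisError raised by the wrapped
+# function and return the fallback value instead of propagating the exception.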
+
+
+def test_redis_fallback_success():
+ @redis_fallback(default_return=None)
+ def test_func():
+ return "success"
+
+ assert test_func() == "success"
+
+
+def test_redis_fallback_error():
+ @redis_fallback(default_return="fallback")
+ def test_func():
+ raise RedisError("Redis error")
+
+ assert test_func() == "fallback"
+
+
+def test_redis_fallback_none_default():
+ @redis_fallback()
+ def test_func():
+ raise RedisError("Redis error")
+
+ assert test_func() is None
+
+
+def test_redis_fallback_with_args():
+ @redis_fallback(default_return=0)
+ def test_func(x, y):
+ raise RedisError("Redis error")
+
+ assert test_func(1, 2) == 0
+
+
+def test_redis_fallback_with_kwargs():
+ @redis_fallback(default_return={})
+ def test_func(x=None, y=None):
+ raise RedisError("Redis error")
+
+ assert test_func(x=1, y=2) == {}
+
+
+def test_redis_fallback_preserves_function_metadata():
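+ # The decorator should preserve the wrapped function's metadata (presumably via functools.wraps).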
+ @redis_fallback(default_return=None)
+ def test_func():
+ """Test function docstring"""
+ pass
+
+ assert test_func.__name__ == "test_func"
+ assert test_func.__doc__ == "Test function docstring"
diff --git a/api/tests/unit_tests/factories/test_build_from_mapping.py b/api/tests/unit_tests/factories/test_build_from_mapping.py
index 48463a369e..d42c4412f5 100644
--- a/api/tests/unit_tests/factories/test_build_from_mapping.py
+++ b/api/tests/unit_tests/factories/test_build_from_mapping.py
@@ -54,8 +54,7 @@ def mock_tool_file():
mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048
- with patch("factories.file_factory.db.session.query") as mock_query:
- mock_query.return_value.filter.return_value.first.return_value = mock
+ with patch("factories.file_factory.db.session.scalar", return_value=mock):
yield mock
@@ -153,8 +152,7 @@ def test_build_from_remote_url(mock_http_head):
def test_tool_file_not_found():
"""Test ToolFile not found in database."""
- with patch("factories.file_factory.db.session.query") as mock_query:
- mock_query.return_value.filter.return_value.first.return_value = None
+ with patch("factories.file_factory.db.session.scalar", return_value=None):
mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py
index 481fbdc91a..4f2542a323 100644
--- a/api/tests/unit_tests/factories/test_variable_factory.py
+++ b/api/tests/unit_tests/factories/test_variable_factory.py
@@ -14,9 +14,7 @@ from core.variables import (
ArrayStringVariable,
FloatVariable,
IntegerVariable,
- ObjectSegment,
SecretVariable,
- SegmentType,
StringVariable,
)
from core.variables.exc import VariableError
@@ -418,8 +416,6 @@ def test_build_segment_file_array_with_different_file_types():
@st.composite
def _generate_file(draw) -> File:
- file_id = draw(st.text(min_size=1, max_size=10))
- tenant_id = draw(st.text(min_size=1, max_size=10))
file_type, mime_type, extension = draw(
st.sampled_from(
[
@@ -509,8 +505,8 @@ def test_build_segment_type_for_scalar():
size=1000,
)
cases = [
- TestCase(0, SegmentType.NUMBER),
- TestCase(0.0, SegmentType.NUMBER),
+ TestCase(0, SegmentType.INTEGER),
+ TestCase(0.0, SegmentType.FLOAT),
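+ # build_segment_type now reports the concrete INTEGER/FLOAT types instead of the generic NUMBER.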
TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE),
]
@@ -535,14 +531,14 @@ class TestBuildSegmentWithType:
result = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result, IntegerSegment)
assert result.value == 42
- assert result.value_type == SegmentType.NUMBER
+ assert result.value_type == SegmentType.INTEGER
def test_number_type_float(self):
"""Test building a number segment with float value."""
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result, FloatSegment)
assert result.value == 3.14
- assert result.value_type == SegmentType.NUMBER
+ assert result.value_type == SegmentType.FLOAT
def test_object_type(self):
"""Test building an object segment with correct type."""
@@ -656,14 +652,14 @@ class TestBuildSegmentWithType:
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, None)
- assert "Expected string, but got None" in str(exc_info.value)
+ assert "expected string, but got None" in str(exc_info.value)
def test_type_mismatch_empty_list_to_non_array(self):
"""Test type mismatch when expecting non-array type but getting empty list."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, [])
- assert "Expected string, but got empty list" in str(exc_info.value)
+ assert "expected string, but got empty list" in str(exc_info.value)
def test_type_mismatch_object_to_array(self):
"""Test type mismatch when expecting array but getting object."""
@@ -678,19 +674,19 @@ class TestBuildSegmentWithType:
# Integer should work
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result_int, IntegerSegment)
- assert result_int.value_type == SegmentType.NUMBER
+ assert result_int.value_type == SegmentType.INTEGER
# Float should work
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result_float, FloatSegment)
- assert result_float.value_type == SegmentType.NUMBER
+ assert result_float.value_type == SegmentType.FLOAT
@pytest.mark.parametrize(
("segment_type", "value", "expected_class"),
[
(SegmentType.STRING, "test", StringSegment),
- (SegmentType.NUMBER, 42, IntegerSegment),
- (SegmentType.NUMBER, 3.14, FloatSegment),
+ (SegmentType.INTEGER, 42, IntegerSegment),
+ (SegmentType.FLOAT, 3.14, FloatSegment),
(SegmentType.OBJECT, {}, ObjectSegment),
(SegmentType.NONE, None, NoneSegment),
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
@@ -861,5 +857,5 @@ class TestBuildSegmentValueErrors:
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
- assert true_segment.value_type == SegmentType.NUMBER
- assert false_segment.value_type == SegmentType.NUMBER
+ assert true_segment.value_type == SegmentType.INTEGER
+ assert false_segment.value_type == SegmentType.INTEGER
diff --git a/api/tests/unit_tests/libs/test_email_i18n.py b/api/tests/unit_tests/libs/test_email_i18n.py
new file mode 100644
index 0000000000..aeb30438e0
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_email_i18n.py
@@ -0,0 +1,539 @@
+"""
+Unit tests for EmailI18nService
+
+Tests the email internationalization service with mocked dependencies
+following Domain-Driven Design principles.
+"""
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+from libs.email_i18n import (
+ EmailI18nConfig,
+ EmailI18nService,
+ EmailLanguage,
+ EmailTemplate,
+ EmailType,
+ FlaskEmailRenderer,
+ FlaskMailSender,
+ create_default_email_config,
+ get_email_i18n_service,
+)
+from services.feature_service import BrandingModel
+
+
+class MockEmailRenderer:
+ """Mock implementation of EmailRenderer protocol"""
+
+ def __init__(self) -> None:
+ self.rendered_templates: list[tuple[str, dict[str, Any]]] = []
+
+ def render_template(self, template_path: str, **context: Any) -> str:
+ """Mock render_template that returns a formatted string"""
+ self.rendered_templates.append((template_path, context))
+ return f"Rendered {template_path} with {context}"
+
+
+class MockBrandingService:
+ """Mock implementation of BrandingService protocol"""
+
+ def __init__(self, enabled: bool = False, application_title: str = "Dify") -> None:
+ self.enabled = enabled
+ self.application_title = application_title
+
+ def get_branding_config(self) -> BrandingModel:
+ """Return mock branding configuration"""
+ branding_model = MagicMock(spec=BrandingModel)
+ branding_model.enabled = self.enabled
+ branding_model.application_title = self.application_title
+ return branding_model
+
+
+class MockEmailSender:
+ """Mock implementation of EmailSender protocol"""
+
+ def __init__(self) -> None:
+ self.sent_emails: list[dict[str, str]] = []
+
+ def send_email(self, to: str, subject: str, html_content: str) -> None:
+ """Mock send_email that records sent emails"""
+ self.sent_emails.append(
+ {
+ "to": to,
+ "subject": subject,
+ "html_content": html_content,
+ }
+ )
+
+
+class TestEmailI18nService:
+ """Test cases for EmailI18nService"""
+
+ @pytest.fixture
+ def email_config(self) -> EmailI18nConfig:
+ """Create test email configuration"""
+ return EmailI18nConfig(
+ templates={
+ EmailType.RESET_PASSWORD: {
+ EmailLanguage.EN_US: EmailTemplate(
+ subject="Reset Your {application_title} Password",
+ template_path="reset_password_en.html",
+ branded_template_path="branded/reset_password_en.html",
+ ),
+ EmailLanguage.ZH_HANS: EmailTemplate(
+ subject="重置您的 {application_title} 密码",
+ template_path="reset_password_zh.html",
+ branded_template_path="branded/reset_password_zh.html",
+ ),
+ },
+ EmailType.INVITE_MEMBER: {
+ EmailLanguage.EN_US: EmailTemplate(
+ subject="Join {application_title} Workspace",
+ template_path="invite_member_en.html",
+ branded_template_path="branded/invite_member_en.html",
+ ),
+ },
+ }
+ )
+
+ @pytest.fixture
+ def mock_renderer(self) -> MockEmailRenderer:
+ """Create mock email renderer"""
+ return MockEmailRenderer()
+
+ @pytest.fixture
+ def mock_branding_service(self) -> MockBrandingService:
+ """Create mock branding service"""
+ return MockBrandingService()
+
+ @pytest.fixture
+ def mock_sender(self) -> MockEmailSender:
+ """Create mock email sender"""
+ return MockEmailSender()
+
+ @pytest.fixture
+ def email_service(
+ self,
+ email_config: EmailI18nConfig,
+ mock_renderer: MockEmailRenderer,
+ mock_branding_service: MockBrandingService,
+ mock_sender: MockEmailSender,
+ ) -> EmailI18nService:
+ """Create EmailI18nService with mocked dependencies"""
+ return EmailI18nService(
+ config=email_config,
+ renderer=mock_renderer,
+ branding_service=mock_branding_service,
+ sender=mock_sender,
+ )
+
+ def test_send_email_with_english_language(
+ self,
+ email_service: EmailI18nService,
+ mock_renderer: MockEmailRenderer,
+ mock_sender: MockEmailSender,
+ ) -> None:
+ """Test sending email with English language"""
+ email_service.send_email(
+ email_type=EmailType.RESET_PASSWORD,
+ language_code="en-US",
+ to="test@example.com",
+ template_context={"reset_link": "https://example.com/reset"},
+ )
+
+ # Verify renderer was called with correct template
+ assert len(mock_renderer.rendered_templates) == 1
+ template_path, context = mock_renderer.rendered_templates[0]
+ assert template_path == "reset_password_en.html"
+ assert context["reset_link"] == "https://example.com/reset"
+ assert context["branding_enabled"] is False
+ assert context["application_title"] == "Dify"
+
+ # Verify email was sent
+ assert len(mock_sender.sent_emails) == 1
+ sent_email = mock_sender.sent_emails[0]
+ assert sent_email["to"] == "test@example.com"
+ assert sent_email["subject"] == "Reset Your Dify Password"
+ assert "reset_password_en.html" in sent_email["html_content"]
+
+ def test_send_email_with_chinese_language(
+ self,
+ email_service: EmailI18nService,
+ mock_sender: MockEmailSender,
+ ) -> None:
+ """Test sending email with Chinese language"""
+ email_service.send_email(
+ email_type=EmailType.RESET_PASSWORD,
+ language_code="zh-Hans",
+ to="test@example.com",
+ template_context={"reset_link": "https://example.com/reset"},
+ )
+
+ # Verify email was sent with Chinese subject
+ assert len(mock_sender.sent_emails) == 1
+ sent_email = mock_sender.sent_emails[0]
+ assert sent_email["subject"] == "重置您的 Dify 密码"
+
+ def test_send_email_with_branding_enabled(
+ self,
+ email_config: EmailI18nConfig,
+ mock_renderer: MockEmailRenderer,
+ mock_sender: MockEmailSender,
+ ) -> None:
+ """Test sending email with branding enabled"""
+ # Create branding service with branding enabled
+ branding_service = MockBrandingService(enabled=True, application_title="MyApp")
+
+ email_service = EmailI18nService(
+ config=email_config,
+ renderer=mock_renderer,
+ branding_service=branding_service,
+ sender=mock_sender,
+ )
+
+ email_service.send_email(
+ email_type=EmailType.RESET_PASSWORD,
+ language_code="en-US",
+ to="test@example.com",
+ )
+
+ # Verify branded template was used
+ assert len(mock_renderer.rendered_templates) == 1
+ template_path, context = mock_renderer.rendered_templates[0]
+ assert template_path == "branded/reset_password_en.html"
+ assert context["branding_enabled"] is True
+ assert context["application_title"] == "MyApp"
+
+ # Verify subject includes custom application title
+ assert len(mock_sender.sent_emails) == 1
+ sent_email = mock_sender.sent_emails[0]
+ assert sent_email["subject"] == "Reset Your MyApp Password"
+
+ def test_send_email_with_language_fallback(
+ self,
+ email_service: EmailI18nService,
+ mock_sender: MockEmailSender,
+ ) -> None:
+ """Test language fallback to English when requested language not available"""
+ # Request invite member in Chinese (not configured)
+ email_service.send_email(
+ email_type=EmailType.INVITE_MEMBER,
+ language_code="zh-Hans",
+ to="test@example.com",
+ )
+
+ # Should fall back to English
+ assert len(mock_sender.sent_emails) == 1
+ sent_email = mock_sender.sent_emails[0]
+ assert sent_email["subject"] == "Join Dify Workspace"
+
+ def test_send_email_with_unknown_language_code(
+ self,
+ email_service: EmailI18nService,
+ mock_sender: MockEmailSender,
+ ) -> None:
+ """Test unknown language code falls back to English"""
+ email_service.send_email(
+ email_type=EmailType.RESET_PASSWORD,
+ language_code="fr-FR", # French not configured
+ to="test@example.com",
+ )
+
+ # Should use English
+ assert len(mock_sender.sent_emails) == 1
+ sent_email = mock_sender.sent_emails[0]
+ assert sent_email["subject"] == "Reset Your Dify Password"
+
+ def test_send_change_email_old_phase(
+ self,
+ email_config: EmailI18nConfig,
+ mock_renderer: MockEmailRenderer,
+ mock_sender: MockEmailSender,
+ mock_branding_service: MockBrandingService,
+ ) -> None:
+ """Test sending change email for old email verification"""
+ # Add change email templates to config
+ email_config.templates[EmailType.CHANGE_EMAIL_OLD] = {
+ EmailLanguage.EN_US: EmailTemplate(
+ subject="Verify your current email",
+ template_path="change_email_old_en.html",
+ branded_template_path="branded/change_email_old_en.html",
+ ),
+ }
+
+ email_service = EmailI18nService(
+ config=email_config,
+ renderer=mock_renderer,
+ branding_service=mock_branding_service,
+ sender=mock_sender,
+ )
+
+ email_service.send_change_email(
+ language_code="en-US",
+ to="old@example.com",
+ code="123456",
+ phase="old_email",
+ )
+
+ # Verify correct template and context
+ assert len(mock_renderer.rendered_templates) == 1
+ template_path, context = mock_renderer.rendered_templates[0]
+ assert template_path == "change_email_old_en.html"
+ assert context["to"] == "old@example.com"
+ assert context["code"] == "123456"
+
+ def test_send_change_email_new_phase(
+ self,
+ email_config: EmailI18nConfig,
+ mock_renderer: MockEmailRenderer,
+ mock_sender: MockEmailSender,
+ mock_branding_service: MockBrandingService,
+ ) -> None:
+ """Test sending change email for new email verification"""
+ # Add change email templates to config
+ email_config.templates[EmailType.CHANGE_EMAIL_NEW] = {
+ EmailLanguage.EN_US: EmailTemplate(
+ subject="Verify your new email",
+ template_path="change_email_new_en.html",
+ branded_template_path="branded/change_email_new_en.html",
+ ),
+ }
+
+ email_service = EmailI18nService(
+ config=email_config,
+ renderer=mock_renderer,
+ branding_service=mock_branding_service,
+ sender=mock_sender,
+ )
+
+ email_service.send_change_email(
+ language_code="en-US",
+ to="new@example.com",
+ code="654321",
+ phase="new_email",
+ )
+
+ # Verify correct template and context
+ assert len(mock_renderer.rendered_templates) == 1
+ template_path, context = mock_renderer.rendered_templates[0]
+ assert template_path == "change_email_new_en.html"
+ assert context["to"] == "new@example.com"
+ assert context["code"] == "654321"
+
+ def test_send_change_email_invalid_phase(
+ self,
+ email_service: EmailI18nService,
+ ) -> None:
+ """Test sending change email with invalid phase raises error"""
+ with pytest.raises(ValueError, match="Invalid phase: invalid_phase"):
+ email_service.send_change_email(
+ language_code="en-US",
+ to="test@example.com",
+ code="123456",
+ phase="invalid_phase",
+ )
+
+ def test_send_raw_email_single_recipient(
+ self,
+ email_service: EmailI18nService,
+ mock_sender: MockEmailSender,
+ ) -> None:
+ """Test sending raw email to single recipient"""
+ email_service.send_raw_email(
+ to="test@example.com",
+ subject="Test Subject",
+ html_content="Test Content",
+ )
+
+ assert len(mock_sender.sent_emails) == 1
+ sent_email = mock_sender.sent_emails[0]
+ assert sent_email["to"] == "test@example.com"
+ assert sent_email["subject"] == "Test Subject"
+ assert sent_email["html_content"] == "Test Content"
+
+ def test_send_raw_email_multiple_recipients(
+ self,
+ email_service: EmailI18nService,
+ mock_sender: MockEmailSender,
+ ) -> None:
+ """Test sending raw email to multiple recipients"""
+ recipients = ["user1@example.com", "user2@example.com", "user3@example.com"]
+
+ email_service.send_raw_email(
+ to=recipients,
+ subject="Test Subject",
+ html_content="Test Content",
+ )
+
+ # Should send individual emails to each recipient
+ assert len(mock_sender.sent_emails) == 3
+ for i, recipient in enumerate(recipients):
+ sent_email = mock_sender.sent_emails[i]
+ assert sent_email["to"] == recipient
+ assert sent_email["subject"] == "Test Subject"
+ assert sent_email["html_content"] == "Test Content"
+
+ def test_get_template_missing_email_type(
+ self,
+ email_config: EmailI18nConfig,
+ ) -> None:
+ """Test getting template for missing email type raises error"""
+ with pytest.raises(ValueError, match="No templates configured for email type"):
+ email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
+
+ def test_get_template_missing_language_and_english(
+ self,
+ email_config: EmailI18nConfig,
+ ) -> None:
+ """Test error when neither requested language nor English fallback exists"""
+ # Add template without English fallback
+ email_config.templates[EmailType.EMAIL_CODE_LOGIN] = {
+ EmailLanguage.ZH_HANS: EmailTemplate(
+ subject="Test",
+ template_path="test.html",
+ branded_template_path="branded/test.html",
+ ),
+ }
+
+ with pytest.raises(ValueError, match="No template found for"):
+ # Request a language that doesn't exist and no English fallback
+ email_config.get_template(EmailType.EMAIL_CODE_LOGIN, EmailLanguage.EN_US)
+
+ def test_subject_templating_with_variables(
+ self,
+ email_config: EmailI18nConfig,
+ mock_renderer: MockEmailRenderer,
+ mock_sender: MockEmailSender,
+ mock_branding_service: MockBrandingService,
+ ) -> None:
+ """Test subject templating with custom variables"""
+ # Add template with variable in subject
+ email_config.templates[EmailType.OWNER_TRANSFER_NEW_NOTIFY] = {
+ EmailLanguage.EN_US: EmailTemplate(
+ subject="You are now the owner of {WorkspaceName}",
+ template_path="owner_transfer_en.html",
+ branded_template_path="branded/owner_transfer_en.html",
+ ),
+ }
+
+ email_service = EmailI18nService(
+ config=email_config,
+ renderer=mock_renderer,
+ branding_service=mock_branding_service,
+ sender=mock_sender,
+ )
+
+ email_service.send_email(
+ email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
+ language_code="en-US",
+ to="test@example.com",
+ template_context={"WorkspaceName": "My Workspace"},
+ )
+
+ # Verify subject was templated correctly
+ assert len(mock_sender.sent_emails) == 1
+ sent_email = mock_sender.sent_emails[0]
+ assert sent_email["subject"] == "You are now the owner of My Workspace"
+
+ def test_email_language_from_language_code(self) -> None:
+ """Test EmailLanguage.from_language_code method"""
+ assert EmailLanguage.from_language_code("zh-Hans") == EmailLanguage.ZH_HANS
+ assert EmailLanguage.from_language_code("en-US") == EmailLanguage.EN_US
+ assert EmailLanguage.from_language_code("fr-FR") == EmailLanguage.EN_US # Fallback
+ assert EmailLanguage.from_language_code("unknown") == EmailLanguage.EN_US # Fallback
+
+
+class TestEmailI18nIntegration:
+ """Integration tests for email i18n components"""
+
+ def test_create_default_email_config(self) -> None:
+ """Test creating default email configuration"""
+ config = create_default_email_config()
+
+ # Verify key email types have at least English template
+ expected_types = [
+ EmailType.RESET_PASSWORD,
+ EmailType.INVITE_MEMBER,
+ EmailType.EMAIL_CODE_LOGIN,
+ EmailType.CHANGE_EMAIL_OLD,
+ EmailType.CHANGE_EMAIL_NEW,
+ EmailType.OWNER_TRANSFER_CONFIRM,
+ EmailType.OWNER_TRANSFER_OLD_NOTIFY,
+ EmailType.OWNER_TRANSFER_NEW_NOTIFY,
+ EmailType.ACCOUNT_DELETION_SUCCESS,
+ EmailType.ACCOUNT_DELETION_VERIFICATION,
+ EmailType.QUEUE_MONITOR_ALERT,
+ EmailType.DOCUMENT_CLEAN_NOTIFY,
+ ]
+
+ for email_type in expected_types:
+ assert email_type in config.templates
+ assert EmailLanguage.EN_US in config.templates[email_type]
+
+ # Verify some have Chinese translations
+ assert EmailLanguage.ZH_HANS in config.templates[EmailType.RESET_PASSWORD]
+ assert EmailLanguage.ZH_HANS in config.templates[EmailType.INVITE_MEMBER]
+
+ def test_get_email_i18n_service(self) -> None:
+ """Test getting global email i18n service instance"""
+ service1 = get_email_i18n_service()
+ service2 = get_email_i18n_service()
+
+ # Should return the same instance
+ assert service1 is service2
+
+ def test_flask_email_renderer(self) -> None:
+ """Test FlaskEmailRenderer implementation"""
+ renderer = FlaskEmailRenderer()
+
+ # Should raise TemplateNotFound when template doesn't exist
+ from jinja2.exceptions import TemplateNotFound
+
+ with pytest.raises(TemplateNotFound):
+ renderer.render_template("test.html", foo="bar")
+
+ def test_flask_mail_sender_not_initialized(self) -> None:
+ """Test FlaskMailSender when mail is not initialized"""
+ sender = FlaskMailSender()
+
+ # Mock mail.is_inited() to return False
+ import libs.email_i18n
+
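+ # Swap the module-level mail object by hand; the finally block restores it so other tests are unaffected.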
+ original_mail = libs.email_i18n.mail
+ mock_mail = MagicMock()
+ mock_mail.is_inited.return_value = False
+ libs.email_i18n.mail = mock_mail
+
+ try:
+ # Should not send email when mail is not initialized
+ sender.send_email("test@example.com", "Subject", "Content")
+ mock_mail.send.assert_not_called()
+ finally:
+ # Restore original mail
+ libs.email_i18n.mail = original_mail
+
+ def test_flask_mail_sender_initialized(self) -> None:
+ """Test FlaskMailSender when mail is initialized"""
+ sender = FlaskMailSender()
+
+ # Mock mail.is_inited() to return True
+ import libs.email_i18n
+
+ original_mail = libs.email_i18n.mail
+ mock_mail = MagicMock()
+ mock_mail.is_inited.return_value = True
+ libs.email_i18n.mail = mock_mail
+
+ try:
+ # Should send email when mail is initialized
+ sender.send_email("test@example.com", "Subject", "Content")
+ mock_mail.send.assert_called_once_with(
+ to="test@example.com",
+ subject="Subject",
+ html="Content",
+ )
+ finally:
+ # Restore original mail
+ libs.email_i18n.mail = original_mail
diff --git a/api/tests/unit_tests/libs/test_helper.py b/api/tests/unit_tests/libs/test_helper.py
new file mode 100644
index 0000000000..b7701055f5
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_helper.py
@@ -0,0 +1,65 @@
+import pytest
+
+from libs.helper import extract_tenant_id
+from models.account import Account
+from models.model import EndUser
+
+
+class TestExtractTenantId:
+ """Test cases for the extract_tenant_id utility function."""
+
+ def test_extract_tenant_id_from_account_with_tenant(self):
+ """Test extracting tenant_id from Account with current_tenant_id."""
+ # Create a mock Account object
+ account = Account()
+ # Mock the current_tenant_id property
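+ # type() builds a throwaway class here; its instance only needs an id attribute for current_tenant_id to read.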
+ account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
+
+ tenant_id = extract_tenant_id(account)
+ assert tenant_id == "account-tenant-123"
+
+ def test_extract_tenant_id_from_account_without_tenant(self):
+ """Test extracting tenant_id from Account without current_tenant_id."""
+ # Create a mock Account object
+ account = Account()
+ account._current_tenant = None
+
+ tenant_id = extract_tenant_id(account)
+ assert tenant_id is None
+
+ def test_extract_tenant_id_from_enduser_with_tenant(self):
+ """Test extracting tenant_id from EndUser with tenant_id."""
+ # Create a mock EndUser object
+ end_user = EndUser()
+ end_user.tenant_id = "enduser-tenant-456"
+
+ tenant_id = extract_tenant_id(end_user)
+ assert tenant_id == "enduser-tenant-456"
+
+ def test_extract_tenant_id_from_enduser_without_tenant(self):
+ """Test extracting tenant_id from EndUser without tenant_id."""
+ # Create a mock EndUser object
+ end_user = EndUser()
+ end_user.tenant_id = None
+
+ tenant_id = extract_tenant_id(end_user)
+ assert tenant_id is None
+
+ def test_extract_tenant_id_with_invalid_user_type(self):
+ """Test extracting tenant_id with invalid user type raises ValueError."""
+ invalid_user = "not_a_user_object"
+
+ with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
+ extract_tenant_id(invalid_user)
+
+ def test_extract_tenant_id_with_none_user(self):
+ """Test extracting tenant_id with None user raises ValueError."""
+ with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
+ extract_tenant_id(None)
+
+ def test_extract_tenant_id_with_dict_user(self):
+ """Test extracting tenant_id with dict user raises ValueError."""
+ dict_user = {"id": "123", "tenant_id": "456"}
+
+ with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
+ extract_tenant_id(dict_user)
diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py
new file mode 100644
index 0000000000..39671077d4
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_login.py
@@ -0,0 +1,232 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from flask import Flask, g
+from flask_login import LoginManager, UserMixin
+
+from libs.login import _get_user, current_user, login_required
+
+
+class MockUser(UserMixin):
+ """Mock user class for testing."""
+
+ def __init__(self, id: str, is_authenticated: bool = True):
+ self.id = id
+ self._is_authenticated = is_authenticated
+
+ @property
+ def is_authenticated(self):
+ return self._is_authenticated
+
+
+class TestLoginRequired:
+ """Test cases for login_required decorator."""
+
+ @pytest.fixture
+ def setup_app(self, app: Flask):
+ """Set up Flask app with login manager."""
+ # Initialize login manager
+ login_manager = LoginManager()
+ login_manager.init_app(app)
+
+ # Mock unauthorized handler
+ login_manager.unauthorized = MagicMock(return_value="Unauthorized")
+
+ # Add a dummy user loader to prevent exceptions
+ @login_manager.user_loader
+ def load_user(user_id):
+ return None
+
+ return app
+
+ def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
+ """Test that authenticated users can access protected views."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context():
+ # Mock authenticated user
+ mock_user = MockUser("test_user", is_authenticated=True)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Protected content"
+
+ def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
+ """Test that unauthenticated users are redirected."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context():
+ # Mock unauthenticated user
+ mock_user = MockUser("test_user", is_authenticated=False)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Unauthorized"
+ setup_app.login_manager.unauthorized.assert_called_once()
+
+ def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
+ """Test that LOGIN_DISABLED config bypasses authentication."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context():
+ # Mock unauthenticated user and LOGIN_DISABLED
+ mock_user = MockUser("test_user", is_authenticated=False)
+ with patch("libs.login._get_user", return_value=mock_user):
+ with patch("libs.login.dify_config") as mock_config:
+ mock_config.LOGIN_DISABLED = True
+
+ result = protected_view()
+ assert result == "Protected content"
+ # Ensure unauthorized was not called
+ setup_app.login_manager.unauthorized.assert_not_called()
+
+ def test_options_request_bypasses_authentication(self, setup_app: Flask):
+ """Test that OPTIONS requests are exempt from authentication."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ with setup_app.test_request_context(method="OPTIONS"):
+ # Mock unauthenticated user
+ mock_user = MockUser("test_user", is_authenticated=False)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Protected content"
+ # Ensure unauthorized was not called
+ setup_app.login_manager.unauthorized.assert_not_called()
+
+ def test_flask_2_compatibility(self, setup_app: Flask):
+ """Test Flask 2.x compatibility with ensure_sync."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ # Mock Flask 2.x ensure_sync
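+ # ensure_sync(view) returns the callable Flask actually invokes; the mock hands back a stand-in.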
+ setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
+
+ with setup_app.test_request_context():
+ mock_user = MockUser("test_user", is_authenticated=True)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Synced content"
+ setup_app.ensure_sync.assert_called_once()
+
+ def test_flask_1_compatibility(self, setup_app: Flask):
+ """Test Flask 1.x compatibility without ensure_sync."""
+
+ @login_required
+ def protected_view():
+ return "Protected content"
+
+ # Remove ensure_sync to simulate Flask 1.x
+ if hasattr(setup_app, "ensure_sync"):
+ delattr(setup_app, "ensure_sync")
+
+ with setup_app.test_request_context():
+ mock_user = MockUser("test_user", is_authenticated=True)
+ with patch("libs.login._get_user", return_value=mock_user):
+ result = protected_view()
+ assert result == "Protected content"
+
+
+class TestGetUser:
+ """Test cases for _get_user function."""
+
+ def test_get_user_returns_user_from_g(self, app: Flask):
+ """Test that _get_user returns user from g._login_user."""
+ mock_user = MockUser("test_user")
+
+ with app.test_request_context():
+ g._login_user = mock_user
+ user = _get_user()
+ assert user == mock_user
+ assert user.id == "test_user"
+
+ def test_get_user_loads_user_if_not_in_g(self, app: Flask):
+ """Test that _get_user loads user if not already in g."""
+ mock_user = MockUser("test_user")
+
+ # Mock login manager
+ login_manager = MagicMock()
+ login_manager._load_user = MagicMock()
+ app.login_manager = login_manager
+
+ with app.test_request_context():
+ # Simulate _load_user setting g._login_user
+ def side_effect():
+ g._login_user = mock_user
+
+ login_manager._load_user.side_effect = side_effect
+
+ user = _get_user()
+ assert user == mock_user
+ login_manager._load_user.assert_called_once()
+
+ def test_get_user_returns_none_without_request_context(self, app: Flask):
+ """Test that _get_user returns None outside request context."""
+ # Outside of request context
+ user = _get_user()
+ assert user is None
+
+
+class TestCurrentUser:
+ """Test cases for current_user proxy."""
+
+ def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
+ """Test that current_user proxy returns authenticated user."""
+ mock_user = MockUser("test_user", is_authenticated=True)
+
+ with app.test_request_context():
+ with patch("libs.login._get_user", return_value=mock_user):
+ assert current_user.id == "test_user"
+ assert current_user.is_authenticated is True
+
+ def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
+ """Test that current_user proxy handles None user."""
+ with app.test_request_context():
+ with patch("libs.login._get_user", return_value=None):
+ # When _get_user returns None, attribute access on current_user should raise AttributeError
+ try:
+ # Try to access an attribute that would exist on a real user
+ _ = current_user.id
+ pytest.fail("Should have raised AttributeError")
+ except AttributeError:
+ # This is expected when current_user is None
+ pass
+
+ def test_current_user_proxy_thread_safety(self, app: Flask):
+ """Test that current_user proxy is thread-safe."""
+ import threading
+
+ results = {}
+
+ def check_user_in_thread(user_id: str, index: int):
+ with app.test_request_context():
+ mock_user = MockUser(user_id)
+ with patch("libs.login._get_user", return_value=mock_user):
+ results[index] = current_user.id
+
+ # Create multiple threads with different users
+ threads = []
+ for i in range(5):
+ thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ # Verify each thread got its own user
+ for i in range(5):
+ assert results[i] == f"user_{i}"
diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py
new file mode 100644
index 0000000000..629d15b81a
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_oauth_clients.py
@@ -0,0 +1,249 @@
+import urllib.parse
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+
+from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
+
+
+class BaseOAuthTest:
+ """Base class for OAuth provider tests with common fixtures"""
+
+ @pytest.fixture
+ def oauth_config(self):
+ return {
+ "client_id": "test_client_id",
+ "client_secret": "test_client_secret",
+ "redirect_uri": "http://localhost/callback",
+ }
+
+ @pytest.fixture
+ def mock_response(self):
+ response = MagicMock()
+ response.json.return_value = {}
+ return response
+
+ def parse_auth_url(self, url):
+ """Helper to parse authorization URL"""
+ parsed = urllib.parse.urlparse(url)
+ params = urllib.parse.parse_qs(parsed.query)
+ return parsed, params
+
+
+class TestGitHubOAuth(BaseOAuthTest):
+ @pytest.fixture
+ def oauth(self, oauth_config):
+ return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
+
+ @pytest.mark.parametrize(
+ ("invite_token", "expected_state"),
+ [
+ (None, None),
+ ("test_invite_token", "test_invite_token"),
+ ("", None),
+ ],
+ )
+ def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
+ url = oauth.get_authorization_url(invite_token)
+ parsed, params = self.parse_auth_url(url)
+
+ assert parsed.scheme == "https"
+ assert parsed.netloc == "github.com"
+ assert parsed.path == "/login/oauth/authorize"
+ assert params["client_id"][0] == oauth_config["client_id"]
+ assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
+ assert params["scope"][0] == "user:email"
+
+ if expected_state:
+ assert params["state"][0] == expected_state
+ else:
+ assert "state" not in params
+
+ @pytest.mark.parametrize(
+ ("response_data", "expected_token", "should_raise"),
+ [
+ ({"access_token": "test_token"}, "test_token", False),
+ ({"error": "invalid_grant"}, None, True),
+ ({}, None, True),
+ ],
+ )
+ @patch("requests.post")
+ def test_should_retrieve_access_token(
+ self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
+ ):
+ mock_response.json.return_value = response_data
+ mock_post.return_value = mock_response
+
+ if should_raise:
+ with pytest.raises(ValueError) as exc_info:
+ oauth.get_access_token("test_code")
+ assert "Error in GitHub OAuth" in str(exc_info.value)
+ else:
+ token = oauth.get_access_token("test_code")
+ assert token == expected_token
+
+ @pytest.mark.parametrize(
+ ("user_data", "email_data", "expected_email"),
+ [
+ # User with primary email
+ (
+ {"id": 12345, "login": "testuser", "name": "Test User"},
+ [
+ {"email": "secondary@example.com", "primary": False},
+ {"email": "primary@example.com", "primary": True},
+ ],
+ "primary@example.com",
+ ),
+ # User with no emails - fallback to noreply
+ ({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
+ # User with only secondary email - fallback to noreply
+ (
+ {"id": 12345, "login": "testuser", "name": "Test User"},
+ [{"email": "secondary@example.com", "primary": False}],
+ "12345+testuser@users.noreply.github.com",
+ ),
+ ],
+ )
+ @patch("requests.get")
+ def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
+ user_response = MagicMock()
+ user_response.json.return_value = user_data
+
+ email_response = MagicMock()
+ email_response.json.return_value = email_data
+
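+ # get_user_info issues two GET requests: the profile endpoint first, then the emails endpoint.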
+ mock_get.side_effect = [user_response, email_response]
+
+ user_info = oauth.get_user_info("test_token")
+
+ assert user_info.id == str(user_data["id"])
+ assert user_info.name == user_data["name"]
+ assert user_info.email == expected_email
+
+ @patch("requests.get")
+ def test_should_handle_network_errors(self, mock_get, oauth):
+ mock_get.side_effect = requests.exceptions.RequestException("Network error")
+
+ with pytest.raises(requests.exceptions.RequestException):
+ oauth.get_raw_user_info("test_token")
+
+
+class TestGoogleOAuth(BaseOAuthTest):
+ @pytest.fixture
+ def oauth(self, oauth_config):
+ return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
+
+ @pytest.mark.parametrize(
+ ("invite_token", "expected_state"),
+ [
+ (None, None),
+ ("test_invite_token", "test_invite_token"),
+ ("", None),
+ ],
+ )
+ def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
+ url = oauth.get_authorization_url(invite_token)
+ parsed, params = self.parse_auth_url(url)
+
+ assert parsed.scheme == "https"
+ assert parsed.netloc == "accounts.google.com"
+ assert parsed.path == "/o/oauth2/v2/auth"
+ assert params["client_id"][0] == oauth_config["client_id"]
+ assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
+ assert params["response_type"][0] == "code"
+ assert params["scope"][0] == "openid email"
+
+ if expected_state:
+ assert params["state"][0] == expected_state
+ else:
+ assert "state" not in params
+
+ @pytest.mark.parametrize(
+ ("response_data", "expected_token", "should_raise"),
+ [
+ ({"access_token": "test_token"}, "test_token", False),
+ ({"error": "invalid_grant"}, None, True),
+ ({}, None, True),
+ ],
+ )
+ @patch("requests.post")
+ def test_should_retrieve_access_token(
+ self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
+ ):
+ mock_response.json.return_value = response_data
+ mock_post.return_value = mock_response
+
+ if should_raise:
+ with pytest.raises(ValueError) as exc_info:
+ oauth.get_access_token("test_code")
+ assert "Error in Google OAuth" in str(exc_info.value)
+ else:
+ token = oauth.get_access_token("test_code")
+ assert token == expected_token
+
+ mock_post.assert_called_once_with(
+ oauth._TOKEN_URL,
+ data={
+ "client_id": oauth_config["client_id"],
+ "client_secret": oauth_config["client_secret"],
+ "code": "test_code",
+ "grant_type": "authorization_code",
+ "redirect_uri": oauth_config["redirect_uri"],
+ },
+ headers={"Accept": "application/json"},
+ )
+
+ @pytest.mark.parametrize(
+ ("user_data", "expected_name"),
+ [
+ ({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
+ ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
+ ],
+ )
+ @patch("requests.get")
+ def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
+ mock_response.json.return_value = user_data
+ mock_get.return_value = mock_response
+
+ user_info = oauth.get_user_info("test_token")
+
+ assert user_info.id == user_data["sub"]
+ assert user_info.name == expected_name
+ assert user_info.email == user_data["email"]
+
+ mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
+
+ @pytest.mark.parametrize(
+ "exception_type",
+ [
+ requests.exceptions.HTTPError,
+ requests.exceptions.ConnectionError,
+ requests.exceptions.Timeout,
+ ],
+ )
+ @patch("requests.get")
+ def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
+ mock_response = MagicMock()
+ mock_response.raise_for_status.side_effect = exception_type("Error")
+ mock_get.return_value = mock_response
+
+ with pytest.raises(exception_type):
+ oauth.get_raw_user_info("invalid_token")
+
+
+class TestOAuthUserInfo:
+ @pytest.mark.parametrize(
+ "user_data",
+ [
+ {"id": "123", "name": "Test User", "email": "test@example.com"},
+ {"id": "456", "name": "", "email": "user@domain.com"},
+ {"id": "789", "name": "Another User", "email": "another@test.org"},
+ ],
+ )
+ def test_should_create_user_info_dataclass(self, user_data):
+ user_info = OAuthUserInfo(**user_data)
+
+ assert user_info.id == user_data["id"]
+ assert user_info.name == user_data["name"]
+ assert user_info.email == user_data["email"]
diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py
new file mode 100644
index 0000000000..f33484c18d
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_passport.py
@@ -0,0 +1,205 @@
+from datetime import UTC, datetime, timedelta
+from unittest.mock import patch
+
+import jwt
+import pytest
+from werkzeug.exceptions import Unauthorized
+
+from libs.passport import PassportService
+
+
+class TestPassportService:
+ """Test PassportService JWT operations"""
+
+ @pytest.fixture
+ def passport_service(self):
+ """Create PassportService instance with test secret key"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "test-secret-key-for-testing"
+ return PassportService()
+
+ @pytest.fixture
+ def another_passport_service(self):
+ """Create another PassportService instance with different secret key"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "another-secret-key-for-testing"
+ return PassportService()
+
+ # Core functionality tests
+ def test_should_issue_and_verify_token(self, passport_service):
+ """Test complete JWT lifecycle: issue and verify"""
+ payload = {"user_id": "123", "app_code": "test-app"}
+ token = passport_service.issue(payload)
+
+ # Verify token format
+ assert isinstance(token, str)
+ assert len(token.split(".")) == 3 # JWT format: header.payload.signature
+
+ # Verify token content
+ decoded = passport_service.verify(token)
+ assert decoded == payload
+
+ def test_should_handle_different_payload_types(self, passport_service):
+ """Test issuing and verifying tokens with different payload types"""
+ test_cases = [
+ {"string": "value"},
+ {"number": 42},
+ {"float": 3.14},
+ {"boolean": True},
+ {"null": None},
+ {"array": [1, 2, 3]},
+ {"nested": {"key": "value"}},
+ {"unicode": "中文测试"},
+ {"emoji": "🔐"},
+ {}, # Empty payload
+ ]
+
+ for payload in test_cases:
+ token = passport_service.issue(payload)
+ decoded = passport_service.verify(token)
+ assert decoded == payload
+
+ # Security tests
+ def test_should_reject_modified_token(self, passport_service):
+ """Test that any modification to token invalidates it"""
+ token = passport_service.issue({"user": "test"})
+
+ # Test multiple modification points
+ test_positions = [0, len(token) // 3, len(token) // 2, len(token) - 1]
+
+ for pos in test_positions:
+ if pos < len(token) and token[pos] != ".":
+ # Change one character
+ tampered = token[:pos] + ("X" if token[pos] != "X" else "Y") + token[pos + 1 :]
+ with pytest.raises(Unauthorized):
+ passport_service.verify(tampered)
+
+ def test_should_reject_token_with_different_secret_key(self, passport_service, another_passport_service):
+ """Test key isolation - token from one service should not work with another"""
+ payload = {"user_id": "123", "app_code": "test-app"}
+ token = passport_service.issue(payload)
+
+ with pytest.raises(Unauthorized) as exc_info:
+ another_passport_service.verify(token)
+ assert str(exc_info.value) == "401 Unauthorized: Invalid token signature."
+
+ def test_should_use_hs256_algorithm(self, passport_service):
+ """Test that HS256 algorithm is used for signing"""
+ payload = {"test": "data"}
+ token = passport_service.issue(payload)
+
+ # Decode header without relying on JWT internals
+ # Use jwt.get_unverified_header which is a public API
+ header = jwt.get_unverified_header(token)
+ assert header["alg"] == "HS256"
+
+ def test_should_reject_token_with_wrong_algorithm(self, passport_service):
+ """Test rejection of token signed with different algorithm"""
+ payload = {"user_id": "123"}
+
+ # Create token with different algorithm
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "test-secret-key-for-testing"
+ # Create token with HS512 instead of HS256
+ wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512")
+
+ # Should fail because service expects HS256
+ # InvalidAlgorithmError is now caught by PyJWTError handler
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify(wrong_alg_token)
+ assert str(exc_info.value) == "401 Unauthorized: Invalid token."
+
+ # Exception handling tests
+ def test_should_handle_invalid_tokens(self, passport_service):
+ """Test handling of various invalid token formats"""
+ invalid_tokens = [
+ ("not.a.token", "Invalid token."),
+ ("invalid-jwt-format", "Invalid token."),
+ ("xxx.yyy.zzz", "Invalid token."),
+ ("a.b", "Invalid token."), # Missing signature
+ ("", "Invalid token."), # Empty string
+ (" ", "Invalid token."), # Whitespace
+ (None, "Invalid token."), # None value
+ # Malformed base64
+ ("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.INVALID_BASE64!@#$.signature", "Invalid token."),
+ ]
+
+ for invalid_token, expected_message in invalid_tokens:
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify(invalid_token)
+ assert expected_message in str(exc_info.value)
+
+ def test_should_reject_expired_token(self, passport_service):
+ """Test rejection of expired token"""
+ past_time = datetime.now(UTC) - timedelta(hours=1)
+ payload = {"user_id": "123", "exp": past_time.timestamp()}
+
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "test-secret-key-for-testing"
+ token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
+
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify(token)
+ assert str(exc_info.value) == "401 Unauthorized: Token has expired."
+
+ # Configuration tests
+ def test_should_handle_empty_secret_key(self):
+ """Test behavior when SECRET_KEY is empty"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = ""
+ service = PassportService()
+
+ # Empty secret key should still work but is insecure
+ payload = {"test": "data"}
+ token = service.issue(payload)
+ decoded = service.verify(token)
+ assert decoded == payload
+
+ def test_should_handle_none_secret_key(self):
+ """Test behavior when SECRET_KEY is None"""
+ with patch("libs.passport.dify_config") as mock_config:
+ mock_config.SECRET_KEY = None
+ service = PassportService()
+
+ payload = {"test": "data"}
+ # JWT library will raise TypeError when secret is None
+ with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)):
+ service.issue(payload)
+
+ # Boundary condition tests
+ def test_should_handle_large_payload(self, passport_service):
+ """Test handling of large payload"""
+ # Test with 100KB instead of 1MB for faster tests
+ large_data = "x" * (100 * 1024)
+ payload = {"data": large_data}
+
+ token = passport_service.issue(payload)
+ decoded = passport_service.verify(token)
+
+ assert decoded["data"] == large_data
+
+ def test_should_handle_special_characters_in_payload(self, passport_service):
+ """Test handling of special characters in payload"""
+ special_payloads = [
+ {"special": "!@#$%^&*()"},
+ {"quotes": 'He said "Hello"'},
+ {"backslash": "path\\to\\file"},
+ {"newline": "line1\nline2"},
+ {"unicode": "🔐🔑🛡️"},
+ {"mixed": "Test123!@#中文🔐"},
+ ]
+
+ for payload in special_payloads:
+ token = passport_service.issue(payload)
+ decoded = passport_service.verify(token)
+ assert decoded == payload
+
+ def test_should_catch_generic_pyjwt_errors(self, passport_service):
+ """Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
+ # Mock jwt.decode to raise a generic PyJWTError
+ with patch("libs.passport.jwt.decode") as mock_decode:
+ mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error")
+
+ with pytest.raises(Unauthorized) as exc_info:
+ passport_service.verify("some-token")
+ assert str(exc_info.value) == "401 Unauthorized: Invalid token."
diff --git a/api/tests/unit_tests/libs/test_password.py b/api/tests/unit_tests/libs/test_password.py
new file mode 100644
index 0000000000..79fc792cc5
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_password.py
@@ -0,0 +1,74 @@
+import base64
+import binascii
+import os
+
+import pytest
+
+from libs.password import compare_password, hash_password, valid_password
+
+
+class TestValidPassword:
+ """Test password format validation"""
+
+ def test_should_accept_valid_passwords(self):
+ """Test accepting valid password formats"""
+ assert valid_password("password123") == "password123"
+ assert valid_password("test1234") == "test1234"
+ assert valid_password("Test123456") == "Test123456"
+
+ def test_should_reject_invalid_passwords(self):
+ """Test rejecting invalid password formats"""
+ # Too short
+ with pytest.raises(ValueError) as exc_info:
+ valid_password("abc123")
+ assert "Password must contain letters and numbers" in str(exc_info.value)
+
+ # No numbers
+ with pytest.raises(ValueError):
+ valid_password("abcdefgh")
+
+ # No letters
+ with pytest.raises(ValueError):
+ valid_password("12345678")
+
+ # Empty
+ with pytest.raises(ValueError):
+ valid_password("")
+
+
+class TestPasswordHashing:
+ """Test password hashing and comparison"""
+
+ def setup_method(self):
+ """Setup test data"""
+ self.password = "test123password"
+ self.salt = os.urandom(16)
+ self.salt_base64 = base64.b64encode(self.salt).decode()
+
+ password_hash = hash_password(self.password, self.salt)
+ self.password_hash_base64 = base64.b64encode(password_hash).decode()
+
+ def test_should_verify_correct_password(self):
+ """Test correct password verification"""
+ result = compare_password(self.password, self.password_hash_base64, self.salt_base64)
+ assert result is True
+
+ def test_should_reject_wrong_password(self):
+ """Test rejection of incorrect passwords"""
+ result = compare_password("wrongpassword", self.password_hash_base64, self.salt_base64)
+ assert result is False
+
+ def test_should_handle_invalid_base64(self):
+ """Test handling of invalid base64 data"""
+ # Invalid base64 hash
+ with pytest.raises(binascii.Error):
+ compare_password(self.password, "invalid_base64!", self.salt_base64)
+
+ # Invalid base64 salt
+ with pytest.raises(binascii.Error):
+ compare_password(self.password, self.password_hash_base64, "invalid_base64!")
+
+ def test_should_be_case_sensitive(self):
+ """Test password case sensitivity"""
+ result = compare_password(self.password.upper(), self.password_hash_base64, self.salt_base64)
+ assert result is False
diff --git a/api/tests/unit_tests/libs/test_uuid_utils.py b/api/tests/unit_tests/libs/test_uuid_utils.py
new file mode 100644
index 0000000000..7dbda95f45
--- /dev/null
+++ b/api/tests/unit_tests/libs/test_uuid_utils.py
@@ -0,0 +1,351 @@
+import struct
+import time
+import uuid
+from unittest import mock
+
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+
+from libs.uuid_utils import _create_uuidv7_bytes, uuidv7, uuidv7_boundary, uuidv7_timestamp
+
+
+# Tests for private helper function _create_uuidv7_bytes
+def test_create_uuidv7_bytes_basic_structure():
+ """Test basic byte structure creation."""
+ timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Should be exactly 16 bytes
+ assert len(result) == 16
+ assert isinstance(result, bytes)
+
+ # Create UUID from bytes to verify it's valid
+ uuid_obj = uuid.UUID(bytes=result)
+ assert uuid_obj.version == 7
+
+
+def test_create_uuidv7_bytes_timestamp_encoding():
+ """Test timestamp is correctly encoded in first 48 bits."""
+ timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ random_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Extract timestamp from first 6 bytes
+ timestamp_bytes = b"\x00\x00" + result[0:6]
+ extracted_timestamp = struct.unpack(">Q", timestamp_bytes)[0]
+
+ assert extracted_timestamp == timestamp_ms
+
+
+def test_create_uuidv7_bytes_version_bits():
+ """Test version bits are set to 7."""
+ timestamp_ms = 1609459200000
+ random_bytes = b"\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00" # Set first 2 bytes to all 1s
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Extract version from bytes 6-7
+ version_and_rand_a = struct.unpack(">H", result[6:8])[0]
+ version = (version_and_rand_a >> 12) & 0x0F
+
+ assert version == 7
+
+
+def test_create_uuidv7_bytes_variant_bits():
+ """Test variant bits are set correctly."""
+ timestamp_ms = 1609459200000
+ random_bytes = b"\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00" # Set byte 8 to all 1s
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Check variant bits in byte 8 (should be 10xxxxxx)
+ variant_byte = result[8]
+ variant_bits = (variant_byte >> 6) & 0b11
+
+ assert variant_bits == 0b10 # Should be binary 10
+
+
+def test_create_uuidv7_bytes_random_data():
+ """Test random bytes are placed correctly."""
+ timestamp_ms = 1609459200000
+ random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22"
+
+ result = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ # Check random data A (12 bits from bytes 6-7, excluding version)
+ version_and_rand_a = struct.unpack(">H", result[6:8])[0]
+ rand_a = version_and_rand_a & 0x0FFF
+ expected_rand_a = struct.unpack(">H", random_bytes[0:2])[0] & 0x0FFF
+ assert rand_a == expected_rand_a
+
+ # Check random data B (bytes 8-15, with variant bits preserved)
+ # Byte 8 should have variant bits set but preserve lower 6 bits
+ expected_byte_8 = (random_bytes[2] & 0x3F) | 0x80
+ assert result[8] == expected_byte_8
+
+ # Bytes 9-15 should match random_bytes[3:10]
+ assert result[9:16] == random_bytes[3:10]
+
+
+def test_create_uuidv7_bytes_zero_random():
+ """Test with zero random bytes (boundary case)."""
+ timestamp_ms = 1609459200000
+ zero_random_bytes = b"\x00" * 10
+
+ result = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
+
+ # Should still be valid UUIDv7
+ uuid_obj = uuid.UUID(bytes=result)
+ assert uuid_obj.version == 7
+
+ # Version nibble 7 with all-zero rand_a bits encodes bytes 6-7 as 0x7000
+ version_and_rand_a = struct.unpack(">H", result[6:8])[0]
+ assert version_and_rand_a == 0x7000
+
+ # Variant byte should be 0x80 (variant bits + zero random bits)
+ assert result[8] == 0x80
+
+ # Remaining bytes should be zero
+ assert result[9:16] == b"\x00" * 7
+
+
+def test_uuidv7_basic_generation():
+ """Test basic UUID generation produces valid UUIDv7."""
+ result = uuidv7()
+
+ # Should be a UUID object
+ assert isinstance(result, uuid.UUID)
+
+ # Should be version 7
+ assert result.version == 7
+
+ # Should have correct variant (RFC 4122 variant)
+ # Variant bits should be 10xxxxxx (0x80-0xBF range)
+ variant_byte = result.bytes[8]
+ assert (variant_byte >> 6) == 0b10
+
+
+def test_uuidv7_with_custom_timestamp():
+ """Test UUID generation with custom timestamp."""
+ custom_timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ result = uuidv7(custom_timestamp)
+
+ assert isinstance(result, uuid.UUID)
+ assert result.version == 7
+
+ # Extract and verify timestamp
+ extracted_timestamp = uuidv7_timestamp(result)
+ assert isinstance(extracted_timestamp, int)
+ assert extracted_timestamp == custom_timestamp # Exact match for integer milliseconds
+
+
+def test_uuidv7_with_none_timestamp(monkeypatch):
+ """Test UUID generation with None timestamp uses current time."""
+ mock_time = 1609459200
+ mock_time_func = mock.Mock(return_value=mock_time)
+ monkeypatch.setattr("time.time", mock_time_func)
+ result = uuidv7(None)
+
+ assert isinstance(result, uuid.UUID)
+ assert result.version == 7
+
+ # Should use the mocked current time (converted to milliseconds)
+ assert mock_time_func.called
+ extracted_timestamp = uuidv7_timestamp(result)
+ assert extracted_timestamp == mock_time * 1000 # seconds converted to milliseconds
+
+
+def test_uuidv7_time_ordering():
+ """Test that sequential UUIDs have increasing timestamps."""
+ # Generate UUIDs with incrementing timestamps (in milliseconds)
+ timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
+ timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
+ timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
+
+ uuid1 = uuidv7(timestamp1)
+ uuid2 = uuidv7(timestamp2)
+ uuid3 = uuidv7(timestamp3)
+
+ # Extract timestamps
+ ts1 = uuidv7_timestamp(uuid1)
+ ts2 = uuidv7_timestamp(uuid2)
+ ts3 = uuidv7_timestamp(uuid3)
+
+ # Should be in ascending order
+ assert ts1 < ts2 < ts3
+
+ # UUIDs should be lexicographically ordered by their string representation
+ # due to time-ordering property of UUIDv7
+ uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
+ assert uuid_strings == sorted(uuid_strings)
+
+
+def test_uuidv7_uniqueness():
+ """Test that multiple calls generate different UUIDs."""
+ # Generate multiple UUIDs with the same timestamp (in milliseconds)
+ timestamp = 1609459200000
+ uuids = [uuidv7(timestamp) for _ in range(100)]
+
+ # All should be unique despite same timestamp (due to random bits)
+ assert len(set(uuids)) == 100
+
+ # All should have the same extracted timestamp
+ for uuid_obj in uuids:
+ extracted_ts = uuidv7_timestamp(uuid_obj)
+ assert extracted_ts == timestamp
+
+
+def test_uuidv7_timestamp_error_handling_wrong_version():
+ """Test error handling for non-UUIDv7 inputs."""
+
+ uuid_v4 = uuid.uuid4()
+ with pytest.raises(ValueError) as exc_ctx:
+ uuidv7_timestamp(uuid_v4)
+ assert "Expected UUIDv7 (version 7)" in str(exc_ctx.value)
+ assert f"got version {uuid_v4.version}" in str(exc_ctx.value)
+
+
+@given(st.integers(min_value=0, max_value=2**48 - 1))
+def test_uuidv7_timestamp_round_trip(timestamp_ms):
+ """Property test: any 48-bit millisecond timestamp survives a uuidv7 round trip."""
+ # Generate UUID with timestamp
+ uuid_obj = uuidv7(timestamp_ms)
+
+ # Extract timestamp back
+ extracted_timestamp = uuidv7_timestamp(uuid_obj)
+
+ # Should match exactly for integer millisecond timestamps
+ assert extracted_timestamp == timestamp_ms
+
+
+def test_uuidv7_timestamp_edge_cases():
+ """Test timestamp extraction with edge case values."""
+ # Test with very small timestamp
+ small_timestamp = 1 # 1ms after epoch
+ uuid_small = uuidv7(small_timestamp)
+ extracted_small = uuidv7_timestamp(uuid_small)
+ assert extracted_small == small_timestamp
+
+ # Test with large timestamp (year 2038+)
+ large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
+ uuid_large = uuidv7(large_timestamp)
+ extracted_large = uuidv7_timestamp(uuid_large)
+ assert extracted_large == large_timestamp
+
+
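+# uuidv7_boundary(ts) is expected to yield the smallest UUIDv7 for a given
+# timestamp (version and variant bits set, every random bit zero), which is
+# what makes it usable as a lower bound when scanning a time-ordered range.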
+def test_uuidv7_boundary_basic_generation():
+ """Test basic boundary UUID generation with a known timestamp."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ result = uuidv7_boundary(timestamp)
+
+ # Should be a UUID object
+ assert isinstance(result, uuid.UUID)
+
+ # Should be version 7
+ assert result.version == 7
+
+ # Should have correct variant (RFC 4122 variant)
+ # Variant bits should be 10xxxxxx (0x80-0xBF range)
+ variant_byte = result.bytes[8]
+ assert (variant_byte >> 6) == 0b10
+
+
+def test_uuidv7_boundary_timestamp_extraction():
+ """Test that boundary UUID timestamp can be extracted correctly."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+ boundary_uuid = uuidv7_boundary(timestamp)
+
+ # Extract timestamp using existing function
+ extracted_timestamp = uuidv7_timestamp(boundary_uuid)
+
+ # Should match exactly
+ assert extracted_timestamp == timestamp
+
+
+def test_uuidv7_boundary_deterministic():
+ """Test that boundary UUIDs are deterministic for same timestamp."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+
+ # Generate multiple boundary UUIDs with same timestamp
+ uuid1 = uuidv7_boundary(timestamp)
+ uuid2 = uuidv7_boundary(timestamp)
+ uuid3 = uuidv7_boundary(timestamp)
+
+ # Should all be identical
+ assert uuid1 == uuid2 == uuid3
+ assert str(uuid1) == str(uuid2) == str(uuid3)
+
+
+def test_uuidv7_boundary_is_minimum():
+ """Test that boundary UUID is lexicographically smaller than regular UUIDs."""
+ timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds
+
+ # Generate boundary UUID
+ boundary_uuid = uuidv7_boundary(timestamp)
+
+ # Generate multiple regular UUIDs with same timestamp
+ regular_uuids = [uuidv7(timestamp) for _ in range(50)]
+
+ # Boundary UUID should be lexicographically smaller than all regular UUIDs
+ boundary_str = str(boundary_uuid)
+ for regular_uuid in regular_uuids:
+ regular_str = str(regular_uuid)
+ assert boundary_str < regular_str, f"Boundary {boundary_str} should be < regular {regular_str}"
+
+ # Also test with bytes comparison
+ boundary_bytes = boundary_uuid.bytes
+ for regular_uuid in regular_uuids:
+ regular_bytes = regular_uuid.bytes
+ assert boundary_bytes < regular_bytes
+
+
+def test_uuidv7_boundary_different_timestamps():
+ """Test that boundary UUIDs with different timestamps are ordered correctly."""
+ timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC
+ timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC
+ timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC
+
+ uuid1 = uuidv7_boundary(timestamp1)
+ uuid2 = uuidv7_boundary(timestamp2)
+ uuid3 = uuidv7_boundary(timestamp3)
+
+ # Extract timestamps to verify
+ ts1 = uuidv7_timestamp(uuid1)
+ ts2 = uuidv7_timestamp(uuid2)
+ ts3 = uuidv7_timestamp(uuid3)
+
+ # Should be in ascending order
+ assert ts1 < ts2 < ts3
+
+ # UUIDs should be lexicographically ordered
+ uuid_strings = [str(uuid1), str(uuid2), str(uuid3)]
+ assert uuid_strings == sorted(uuid_strings)
+
+ # Bytes should also be ordered
+ assert uuid1.bytes < uuid2.bytes < uuid3.bytes
+
+
+def test_uuidv7_boundary_edge_cases():
+ """Test boundary UUID generation with edge case timestamp values."""
+ # Test with timestamp 0 (Unix epoch)
+ epoch_uuid = uuidv7_boundary(0)
+ assert isinstance(epoch_uuid, uuid.UUID)
+ assert epoch_uuid.version == 7
+ assert uuidv7_timestamp(epoch_uuid) == 0
+
+ # Test with very large timestamp values
+ large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds
+ large_uuid = uuidv7_boundary(large_timestamp)
+ assert isinstance(large_uuid, uuid.UUID)
+ assert large_uuid.version == 7
+ assert uuidv7_timestamp(large_uuid) == large_timestamp
+
+ # Test with current time
+ current_time = int(time.time() * 1000)
+ current_uuid = uuidv7_boundary(current_time)
+ assert isinstance(current_uuid, uuid.UUID)
+ assert current_uuid.version == 7
+ assert uuidv7_timestamp(current_uuid) == current_time
diff --git a/api/tests/unit_tests/models/test_types_enum_text.py b/api/tests/unit_tests/models/test_types_enum_text.py
index 3afa0f17a0..e4061b72c7 100644
--- a/api/tests/unit_tests/models/test_types_enum_text.py
+++ b/api/tests/unit_tests/models/test_types_enum_text.py
@@ -6,7 +6,7 @@ import pytest
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import insert
-from sqlalchemy.orm import DeclarativeBase, Mapped, Session
+from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
from sqlalchemy.sql.sqltypes import VARCHAR
from models.types import EnumText
@@ -32,22 +32,26 @@ class _EnumWithLongValue(StrEnum):
class _User(_Base):
__tablename__ = "users"
- id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
- name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False)
- user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
- user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True)
+ id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
+ name: Mapped[str] = mapped_column(sa.String(length=255), nullable=False)
+ user_type: Mapped[_UserType] = mapped_column(
+ EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal
+ )
+ user_type_nullable: Mapped[_UserType | None] = mapped_column(EnumText(enum_class=_UserType), nullable=True)
class _ColumnTest(_Base):
__tablename__ = "column_test"
- id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
+ id: Mapped[int] = mapped_column(sa.Integer, primary_key=True)
- user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
- explicit_length: Mapped[_UserType | None] = sa.Column(
+ user_type: Mapped[_UserType] = mapped_column(
+ EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal
+ )
+ explicit_length: Mapped[_UserType | None] = mapped_column(
EnumText(_UserType, length=50), nullable=True, default=_UserType.normal
)
- long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
+ long_value: Mapped[_EnumWithLongValue] = mapped_column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
_T = TypeVar("_T")
@@ -110,12 +114,12 @@ class TestEnumText:
session.commit()
with Session(engine) as session:
- user = session.query(_User).filter(_User.id == admin_user_id).first()
+ user = session.query(_User).where(_User.id == admin_user_id).first()
assert user.user_type == _UserType.admin
assert user.user_type_nullable is None
with Session(engine) as session:
- user = session.query(_User).filter(_User.id == normal_user_id).first()
+ user = session.query(_User).where(_User.id == normal_user_id).first()
assert user.user_type == _UserType.normal
assert user.user_type_nullable == _UserType.normal
@@ -184,4 +188,4 @@ class TestEnumText:
with pytest.raises(ValueError) as exc:
with Session(engine) as session:
- _user = session.query(_User).filter(_User.id == 1).first()
+ _user = session.query(_User).where(_User.id == 1).first()
diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py
index 69163d48bd..5bc77ad0ef 100644
--- a/api/tests/unit_tests/models/test_workflow.py
+++ b/api/tests/unit_tests/models/test_workflow.py
@@ -9,6 +9,7 @@ from core.file.models import File
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from core.variables.segments import IntegerSegment, Segment
from factories.variable_factory import build_segment
+from models.model import EndUser
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
@@ -43,7 +44,7 @@ def test_environment_variables():
)
# Mock current_user as an EndUser
- mock_user = mock.Mock()
+ mock_user = mock.Mock(spec=EndUser)
mock_user.tenant_id = "tenant_id"
with (
@@ -90,7 +91,7 @@ def test_update_environment_variables():
)
# Mock current_user as an EndUser
- mock_user = mock.Mock()
+ mock_user = mock.Mock(spec=EndUser)
mock_user.tenant_id = "tenant_id"
with (
@@ -136,7 +137,7 @@ def test_to_dict():
# Create some EnvironmentVariable instances
# Mock current_user as an EndUser
- mock_user = mock.Mock()
+ mock_user = mock.Mock(spec=EndUser)
mock_user.tenant_id = "tenant_id"
with (
diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
index 643efb0a0c..c60800c493 100644
--- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
+++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
@@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session):
session_obj.merge.assert_called_once_with(modified_execution)
-def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
- """Test get_by_node_execution_id method."""
- session_obj, _ = session
- # Set up mock
- mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
- mock_stmt = mocker.MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
-
- # Create a properly configured mock execution
- mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
- configure_mock_execution(mock_execution)
- session_obj.scalar.return_value = mock_execution
-
- # Create a mock domain model to be returned by _to_domain_model
- mock_domain_model = mocker.MagicMock()
- # Mock the _to_domain_model method to return our mock domain model
- repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
-
- # Call method
- result = repository.get_by_node_execution_id("test-node-execution-id")
-
- # Assert select was called with correct parameters
- mock_select.assert_called_once()
- session_obj.scalar.assert_called_once_with(mock_stmt)
- # Assert _to_domain_model was called with the mock execution
- repository._to_domain_model.assert_called_once_with(mock_execution)
- # Assert the result is our mock domain model
- assert result is mock_domain_model
-
-
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
"""Test get_by_workflow_run method."""
session_obj, _ = session
@@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
assert result[0] is mock_domain_model
-def test_get_running_executions(repository, session, mocker: MockerFixture):
- """Test get_running_executions method."""
- session_obj, _ = session
- # Set up mock
- mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
- mock_stmt = mocker.MagicMock()
- mock_select.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
-
- # Create a properly configured mock execution
- mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
- configure_mock_execution(mock_execution)
- session_obj.scalars.return_value.all.return_value = [mock_execution]
-
- # Create a mock domain model to be returned by _to_domain_model
- mock_domain_model = mocker.MagicMock()
- # Mock the _to_domain_model method to return our mock domain model
- repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
-
- # Call method
- result = repository.get_running_executions("test-workflow-run-id")
-
- # Assert select was called with correct parameters
- mock_select.assert_called_once()
- session_obj.scalars.assert_called_once_with(mock_stmt)
- # Assert _to_domain_model was called with the mock execution
- repository._to_domain_model.assert_called_once_with(mock_execution)
- # Assert the result contains our mock domain model
- assert len(result) == 1
- assert result[0] is mock_domain_model
-
-
-def test_update_via_save(repository, session):
- """Test updating an existing record via save method."""
- session_obj, _ = session
- # Create a mock execution
- execution = MagicMock(spec=WorkflowNodeExecutionModel)
- execution.tenant_id = None
- execution.app_id = None
- execution.inputs = None
- execution.process_data = None
- execution.outputs = None
- execution.metadata = None
-
- # Mock the to_db_model method to return the execution itself
- # This simulates the behavior of setting tenant_id and app_id
- repository.to_db_model = MagicMock(return_value=execution)
-
- # Call save method to update an existing record
- repository.save(execution)
-
- # Assert to_db_model was called with the execution
- repository.to_db_model.assert_called_once_with(execution)
-
- # Assert session.merge was called (for updates)
- session_obj.merge.assert_called_once_with(execution)
-
-
-def test_clear(repository, session, mocker: MockerFixture):
- """Test clear method."""
- session_obj, _ = session
- # Set up mock
- mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
- mock_stmt = mocker.MagicMock()
- mock_delete.return_value = mock_stmt
- mock_stmt.where.return_value = mock_stmt
-
- # Mock the execute result with rowcount
- mock_result = mocker.MagicMock()
- mock_result.rowcount = 5 # Simulate 5 records deleted
- session_obj.execute.return_value = mock_result
-
- # Call method
- repository.clear()
-
- # Assert delete was called with correct parameters
- mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
- mock_stmt.where.assert_called()
- session_obj.execute.assert_called_once_with(mock_stmt)
- session_obj.commit.assert_called_once()
-
-
def test_to_db_model(repository):
"""Test to_db_model method."""
# Create a domain model
diff --git a/api/tests/unit_tests/services/auth/__init__.py b/api/tests/unit_tests/services/auth/__init__.py
new file mode 100644
index 0000000000..852a892730
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/__init__.py
@@ -0,0 +1 @@
+# API authentication service test module
diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_base.py b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py
new file mode 100644
index 0000000000..b5d91ef3fb
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_api_key_auth_base.py
@@ -0,0 +1,49 @@
+import pytest
+
+from services.auth.api_key_auth_base import ApiKeyAuthBase
+
+
+class ConcreteApiKeyAuth(ApiKeyAuthBase):
+ """Concrete implementation for testing abstract base class"""
+
+ def validate_credentials(self):
+ return True
+
+
+class TestApiKeyAuthBase:
+ def test_should_store_credentials_on_init(self):
+ """Test that credentials are properly stored during initialization"""
+ credentials = {"api_key": "test_key", "auth_type": "bearer"}
+ auth = ConcreteApiKeyAuth(credentials)
+ assert auth.credentials == credentials
+
+ def test_should_not_instantiate_abstract_class(self):
+ """Test that ApiKeyAuthBase cannot be instantiated directly"""
+ credentials = {"api_key": "test_key"}
+
+ with pytest.raises(TypeError) as exc_info:
+ ApiKeyAuthBase(credentials)
+
+ assert "Can't instantiate abstract class" in str(exc_info.value)
+ assert "validate_credentials" in str(exc_info.value)
+
+ def test_should_allow_subclass_implementation(self):
+ """Test that subclasses can properly implement the abstract method"""
+ credentials = {"api_key": "test_key", "auth_type": "bearer"}
+ auth = ConcreteApiKeyAuth(credentials)
+
+ # Should not raise any exception
+ result = auth.validate_credentials()
+ assert result is True
+
+ def test_should_handle_empty_credentials(self):
+ """Test initialization with empty credentials"""
+ credentials = {}
+ auth = ConcreteApiKeyAuth(credentials)
+ assert auth.credentials == {}
+
+ def test_should_handle_none_credentials(self):
+ """Test initialization with None credentials"""
+ credentials = None
+ auth = ConcreteApiKeyAuth(credentials)
+ assert auth.credentials is None
diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py
new file mode 100644
index 0000000000..9d9cb7c6d5
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_api_key_auth_factory.py
@@ -0,0 +1,81 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from services.auth.api_key_auth_factory import ApiKeyAuthFactory
+from services.auth.auth_type import AuthType
+
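+# Assumption pinned by these tests: get_apikey_auth_factory maps each AuthType
+# value to its provider-specific auth class and raises ValueError("Invalid
+# provider") for anything else.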
+
+class TestApiKeyAuthFactory:
+ """Test cases for ApiKeyAuthFactory"""
+
+ @pytest.mark.parametrize(
+ ("provider", "auth_class_path"),
+ [
+ (AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
+ (AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
+ (AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
+ ],
+ )
+ def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
+ """Test getting auth factory for all valid providers"""
+ with patch(auth_class_path) as mock_auth:
+ auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
+ assert auth_class == mock_auth
+
+ @pytest.mark.parametrize(
+ "invalid_provider",
+ [
+ "invalid_provider",
+ "",
+ None,
+ 123,
+ "UNSUPPORTED",
+ ],
+ )
+ def test_get_apikey_auth_factory_invalid_providers(self, invalid_provider):
+ """Test getting auth factory with various invalid providers"""
+ with pytest.raises(ValueError) as exc_info:
+ ApiKeyAuthFactory.get_apikey_auth_factory(invalid_provider)
+ assert str(exc_info.value) == "Invalid provider"
+
+ @pytest.mark.parametrize(
+ ("credentials_return_value", "expected_result"),
+ [
+ (True, True),
+ (False, False),
+ ],
+ )
+ @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
+ def test_validate_credentials_delegates_to_auth_instance(
+ self, mock_get_factory, credentials_return_value, expected_result
+ ):
+ """Test that validate_credentials delegates to auth instance correctly"""
+ # Arrange
+ mock_auth_instance = MagicMock()
+ mock_auth_instance.validate_credentials.return_value = credentials_return_value
+ mock_auth_class = MagicMock(return_value=mock_auth_instance)
+ mock_get_factory.return_value = mock_auth_class
+
+ # Act
+ factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
+ result = factory.validate_credentials()
+
+ # Assert
+ assert result is expected_result
+ mock_auth_instance.validate_credentials.assert_called_once()
+
+ @patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
+ def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
+ """Test that exceptions from auth instance are propagated"""
+ # Arrange
+ mock_auth_instance = MagicMock()
+ mock_auth_instance.validate_credentials.side_effect = Exception("Authentication error")
+ mock_auth_class = MagicMock(return_value=mock_auth_instance)
+ mock_get_factory.return_value = mock_auth_class
+
+ # Act & Assert
+ factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, {"api_key": "test_key"})
+ with pytest.raises(Exception) as exc_info:
+ factory.validate_credentials()
+ assert str(exc_info.value) == "Authentication error"
diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py
new file mode 100644
index 0000000000..dc42a04cf3
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py
@@ -0,0 +1,383 @@
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+
+from models.source import DataSourceApiKeyAuthBinding
+from services.auth.api_key_auth_service import ApiKeyAuthService
+
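+# db.session is patched at the service module's import path, so these tests
+# exercise query construction and control flow without a real database.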
+
+class TestApiKeyAuthService:
+ """API key authentication service security tests"""
+
+ def setup_method(self):
+ """Setup test fixtures"""
+ self.tenant_id = "test_tenant_123"
+ self.category = "search"
+ self.provider = "google"
+ self.binding_id = "binding_123"
+ self.mock_credentials = {"auth_type": "api_key", "config": {"api_key": "test_secret_key_123"}}
+ self.mock_args = {"category": self.category, "provider": self.provider, "credentials": self.mock_credentials}
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_provider_auth_list_success(self, mock_session):
+ """Test get provider auth list - success scenario"""
+ # Mock database query result
+ mock_binding = Mock()
+ mock_binding.tenant_id = self.tenant_id
+ mock_binding.provider = self.provider
+ mock_binding.disabled = False
+
+ mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
+
+ result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
+
+ assert len(result) == 1
+ assert result[0].tenant_id == self.tenant_id
+ mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_provider_auth_list_empty(self, mock_session):
+ """Test get provider auth list - empty result"""
+ mock_session.query.return_value.where.return_value.all.return_value = []
+
+ result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
+
+ assert result == []
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_provider_auth_list_filters_disabled(self, mock_session):
+ """Test get provider auth list - filters disabled items"""
+ mock_session.query.return_value.where.return_value.all.return_value = []
+
+ ApiKeyAuthService.get_provider_auth_list(self.tenant_id)
+
+ # Verify where conditions include disabled.is_(False)
+ where_call = mock_session.query.return_value.where.call_args[0]
+ assert len(where_call) == 2 # tenant_id and disabled filter conditions
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_success(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - success scenario"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption
+ encrypted_key = "encrypted_test_key_123"
+ mock_encrypter.encrypt_token.return_value = encrypted_key
+
+ # Mock database operations
+ mock_session.add = Mock()
+ mock_session.commit = Mock()
+
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ # Verify factory class calls
+ mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
+ mock_auth_instance.validate_credentials.assert_called_once()
+
+ # Verify encryption calls
+ mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, "test_secret_key_123")
+
+ # Verify database operations
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_called_once()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ def test_create_provider_auth_validation_failed(self, mock_factory, mock_session):
+ """Test create provider auth - validation failed"""
+ # Mock failed auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = False
+ mock_factory.return_value = mock_auth_instance
+
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ # Verify no database operations when validation fails
+ mock_session.add.assert_not_called()
+ mock_session.commit.assert_not_called()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_encrypts_api_key(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - ensures API key is encrypted"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption
+ encrypted_key = "encrypted_test_key_123"
+ mock_encrypter.encrypt_token.return_value = encrypted_key
+
+ # Mock database operations
+ mock_session.add = Mock()
+ mock_session.commit = Mock()
+
+ args_copy = self.mock_args.copy()
+ original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore
+
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
+
+ # Verify original key is replaced with encrypted key
+ assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore
+ assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore
+
+ # Verify encryption function is called correctly
+ mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_success(self, mock_session):
+ """Test get auth credentials - success scenario"""
+ # Mock database query result
+ mock_binding = Mock()
+ mock_binding.credentials = json.dumps(self.mock_credentials)
+ mock_session.query.return_value.where.return_value.first.return_value = mock_binding
+
+ result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ assert result == self.mock_credentials
+ mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_not_found(self, mock_session):
+ """Test get auth credentials - not found"""
+ mock_session.query.return_value.where.return_value.first.return_value = None
+
+ result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ assert result is None
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_filters_correctly(self, mock_session):
+ """Test get auth credentials - applies correct filters"""
+ mock_session.query.return_value.where.return_value.first.return_value = None
+
+ ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ # Verify where conditions are correct
+ where_call = mock_session.query.return_value.where.call_args[0]
+ assert len(where_call) == 4 # tenant_id, category, provider, disabled
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_json_parsing(self, mock_session):
+ """Test get auth credentials - JSON parsing"""
+ # Mock credentials with special characters
+ special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}}
+
+ mock_binding = Mock()
+ mock_binding.credentials = json.dumps(special_credentials, ensure_ascii=False)
+ mock_session.query.return_value.where.return_value.first.return_value = mock_binding
+
+ result = ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ assert result == special_credentials
+ assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_delete_provider_auth_success(self, mock_session):
+ """Test delete provider auth - success scenario"""
+ # Mock database query result
+ mock_binding = Mock()
+ mock_session.query.return_value.where.return_value.first.return_value = mock_binding
+
+ ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
+
+ # Verify delete operations
+ mock_session.delete.assert_called_once_with(mock_binding)
+ mock_session.commit.assert_called_once()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_delete_provider_auth_not_found(self, mock_session):
+ """Test delete provider auth - not found"""
+ mock_session.query.return_value.where.return_value.first.return_value = None
+
+ ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
+
+ # Verify no delete operations when not found
+ mock_session.delete.assert_not_called()
+ mock_session.commit.assert_not_called()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_delete_provider_auth_filters_by_tenant(self, mock_session):
+ """Test delete provider auth - filters by tenant"""
+ mock_session.query.return_value.where.return_value.first.return_value = None
+
+ ApiKeyAuthService.delete_provider_auth(self.tenant_id, self.binding_id)
+
+ # Verify where conditions include tenant_id and binding_id
+ where_call = mock_session.query.return_value.where.call_args[0]
+ assert len(where_call) == 2
+
+ def test_validate_api_key_auth_args_success(self):
+ """Test API key auth args validation - success scenario"""
+ # Should not raise any exception
+ ApiKeyAuthService.validate_api_key_auth_args(self.mock_args)
+
+ def test_validate_api_key_auth_args_missing_category(self):
+ """Test API key auth args validation - missing category"""
+ args = self.mock_args.copy()
+ del args["category"]
+
+ with pytest.raises(ValueError, match="category is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_category(self):
+ """Test API key auth args validation - empty category"""
+ args = self.mock_args.copy()
+ args["category"] = ""
+
+ with pytest.raises(ValueError, match="category is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_missing_provider(self):
+ """Test API key auth args validation - missing provider"""
+ args = self.mock_args.copy()
+ del args["provider"]
+
+ with pytest.raises(ValueError, match="provider is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_provider(self):
+ """Test API key auth args validation - empty provider"""
+ args = self.mock_args.copy()
+ args["provider"] = ""
+
+ with pytest.raises(ValueError, match="provider is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_missing_credentials(self):
+ """Test API key auth args validation - missing credentials"""
+ args = self.mock_args.copy()
+ del args["credentials"]
+
+ with pytest.raises(ValueError, match="credentials is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_credentials(self):
+ """Test API key auth args validation - empty credentials"""
+ args = self.mock_args.copy()
+ args["credentials"] = None # type: ignore
+
+ with pytest.raises(ValueError, match="credentials is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_invalid_credentials_type(self):
+ """Test API key auth args validation - invalid credentials type"""
+ args = self.mock_args.copy()
+ args["credentials"] = "not_a_dict"
+
+ with pytest.raises(ValueError, match="credentials must be a dictionary"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_missing_auth_type(self):
+ """Test API key auth args validation - missing auth_type"""
+ args = self.mock_args.copy()
+ del args["credentials"]["auth_type"] # type: ignore
+
+ with pytest.raises(ValueError, match="auth_type is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ def test_validate_api_key_auth_args_empty_auth_type(self):
+ """Test API key auth args validation - empty auth_type"""
+ args = self.mock_args.copy()
+ args["credentials"]["auth_type"] = "" # type: ignore
+
+ with pytest.raises(ValueError, match="auth_type is required"):
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ @pytest.mark.parametrize(
+ "malicious_input",
+ [
+ "",
+ "'; DROP TABLE users; --",
+ "../../../etc/passwd",
+ "\\x00\\x00", # null bytes
+ "A" * 10000, # very long input
+ ],
+ )
+ def test_validate_api_key_auth_args_malicious_input(self, malicious_input):
+ """Test API key auth args validation - malicious input"""
+ args = self.mock_args.copy()
+ args["category"] = malicious_input
+
+ # The validator treats these as opaque strings: it should accept them
+ # without crashing rather than raise security-related exceptions.
+ ApiKeyAuthService.validate_api_key_auth_args(args)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - database error handling"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption
+ mock_encrypter.encrypt_token.return_value = "encrypted_key"
+
+ # Mock database error
+ mock_session.commit.side_effect = Exception("Database error")
+
+ with pytest.raises(Exception, match="Database error"):
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_get_auth_credentials_invalid_json(self, mock_session):
+ """Test get auth credentials - invalid JSON"""
+ # Mock database returning invalid JSON
+ mock_binding = Mock()
+ mock_binding.credentials = "invalid json content"
+ mock_session.query.return_value.where.return_value.first.return_value = mock_binding
+
+ with pytest.raises(json.JSONDecodeError):
+ ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ def test_create_provider_auth_factory_exception(self, mock_factory, mock_session):
+ """Test create provider auth - factory exception"""
+ # Mock factory raising exception
+ mock_factory.side_effect = Exception("Factory error")
+
+ with pytest.raises(Exception, match="Factory error"):
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
+ @patch("services.auth.api_key_auth_service.encrypter")
+ def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session):
+ """Test create provider auth - encryption exception"""
+ # Mock successful auth validation
+ mock_auth_instance = Mock()
+ mock_auth_instance.validate_credentials.return_value = True
+ mock_factory.return_value = mock_auth_instance
+
+ # Mock encryption exception
+ mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
+
+ with pytest.raises(Exception, match="Encryption error"):
+ ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
+
+ def test_validate_api_key_auth_args_none_input(self):
+ """Test API key auth args validation - None input"""
+ with pytest.raises(TypeError):
+ ApiKeyAuthService.validate_api_key_auth_args(None)
+
+ def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
+ """Test API key auth args validation - dict credentials with list auth_type"""
+ args = self.mock_args.copy()
+ args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string
+
+ # The current implementation only checks that auth_type is present and truthy;
+ # the non-empty list ["api_key"] is truthy, so no exception is expected.
+ ApiKeyAuthService.validate_api_key_auth_args(args)
diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py
new file mode 100644
index 0000000000..4ce5525942
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_auth_integration.py
@@ -0,0 +1,234 @@
+"""
+API Key Authentication System Integration Tests
+"""
+
+import json
+from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import Mock, patch
+
+import pytest
+import requests
+
+from services.auth.api_key_auth_factory import ApiKeyAuthFactory
+from services.auth.api_key_auth_service import ApiKeyAuthService
+from services.auth.auth_type import AuthType
+
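+# "Integration" here means the service, factory, and provider layers run
+# together; HTTP, token encryption, and the DB session are mocked where used.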
+
+class TestAuthIntegration:
+ def setup_method(self):
+ self.tenant_id_1 = "tenant_123"
+ self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing
+ self.category = "search"
+
+ # Realistic authentication configurations
+ self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}
+ self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}}
+ self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
+ def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
+ """Test complete authentication flow: request → validation → encryption → storage"""
+ mock_http.return_value = self._create_success_response()
+ mock_encrypt.return_value = "encrypted_fc_test_key_123"
+ mock_session.add = Mock()
+ mock_session.commit = Mock()
+
+ args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
+ ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
+
+ mock_http.assert_called_once()
+ call_args = mock_http.call_args
+ assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0]
+ assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123"
+
+ mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123")
+ mock_session.add.assert_called_once()
+ mock_session.commit.assert_called_once()
+
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_cross_component_integration(self, mock_http):
+ """Test factory → provider → HTTP call integration"""
+ mock_http.return_value = self._create_success_response()
+ factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
+ result = factory.validate_credentials()
+
+ assert result is True
+ mock_http.assert_called_once()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_multi_tenant_isolation(self, mock_session):
+ """Ensure complete tenant data isolation"""
+ tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
+ tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)
+
+ mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
+ result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)
+
+ mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
+ result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)
+
+ assert len(result1) == 1
+ assert result1[0].tenant_id == self.tenant_id_1
+ assert len(result2) == 1
+ assert result2[0].tenant_id == self.tenant_id_2
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ def test_cross_tenant_access_prevention(self, mock_session):
+ """Test prevention of cross-tenant credential access"""
+ mock_session.query.return_value.where.return_value.first.return_value = None
+
+ result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL)
+
+ assert result is None
+
+ def test_sensitive_data_protection(self):
+ """Ensure API keys don't leak to logs"""
+ credentials_with_secrets = {
+ "auth_type": "bearer",
+ "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"},
+ }
+
+ factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets)
+ factory_str = str(factory)
+
+ assert "super_secret_key_do_not_log" not in factory_str
+ assert "another_secret" not in factory_str
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
+ def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
+ """Test concurrent authentication creation safety"""
+ mock_http.return_value = self._create_success_response()
+ mock_encrypt.return_value = "encrypted_key"
+ mock_session.add = Mock()
+ mock_session.commit = Mock()
+
+ args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
+
+ results = []
+ exceptions = []
+
+ def create_auth():
+ try:
+ ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
+ results.append("success")
+ except Exception as e:
+ exceptions.append(e)
+
+ with ThreadPoolExecutor(max_workers=5) as executor:
+ futures = [executor.submit(create_auth) for _ in range(5)]
+ for future in futures:
+ future.result()
+
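+ # With db.session fully mocked, these asserts only show that all five calls
+ # completed and reached the session; they do not prove transactional isolation.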
+ assert len(results) == 5
+ assert len(exceptions) == 0
+ assert mock_session.add.call_count == 5
+ assert mock_session.commit.call_count == 5
+
+ @pytest.mark.parametrize(
+ "invalid_input",
+ [
+ None, # Null input
+ {}, # Empty dictionary - missing required fields
+ {"auth_type": "bearer"}, # Missing config section
+ {"auth_type": "bearer", "config": {}}, # Missing api_key
+ ],
+ )
+ def test_invalid_input_boundary(self, invalid_input):
+ """Test boundary handling for invalid inputs"""
+ with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
+ ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
+
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_http_error_handling(self, mock_http):
+ """Test proper HTTP error handling"""
+ mock_response = Mock()
+ mock_response.status_code = 401
+ mock_response.text = '{"error": "Unauthorized"}'
+ mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized")
+ mock_http.return_value = mock_response
+
+ # Build the factory outside the raises block so it holds a single statement (ruff PT012)
+ factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
+ with pytest.raises(Exception): # HTTPError, or whatever the auth layer wraps it in
+ factory.validate_credentials()
+
+ @patch("services.auth.api_key_auth_service.db.session")
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_network_failure_recovery(self, mock_http, mock_session):
+ """Test system recovery from network failures"""
+ mock_http.side_effect = requests.exceptions.RequestException("Network timeout")
+ mock_session.add = Mock()
+ mock_session.commit = Mock()
+
+ args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
+
+ with pytest.raises(requests.exceptions.RequestException):
+ ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
+
+ mock_session.commit.assert_not_called()
+
+ @pytest.mark.parametrize(
+ ("provider", "credentials"),
+ [
+ (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}),
+ (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}),
+ (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}),
+ ],
+ )
+ def test_all_providers_factory_creation(self, provider, credentials):
+ """Test factory creation for all supported providers"""
+ try:
+ auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
+ assert auth_class is not None
+
+ factory = ApiKeyAuthFactory(provider, credentials)
+ assert factory.auth is not None
+ except ImportError:
+ pytest.skip(f"Provider {provider} not implemented yet")
+
+ def _create_success_response(self, status_code=200):
+ """Create successful HTTP response mock"""
+ mock_response = Mock()
+ mock_response.status_code = status_code
+ mock_response.json.return_value = {"status": "success"}
+ mock_response.raise_for_status.return_value = None
+ return mock_response
+
+ def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock:
+ """Create realistic database binding mock"""
+ mock_binding = Mock()
+ mock_binding.id = f"binding_{provider}_{tenant_id}"
+ mock_binding.tenant_id = tenant_id
+ mock_binding.category = self.category
+ mock_binding.provider = provider
+ mock_binding.credentials = json.dumps(credentials, ensure_ascii=False)
+ mock_binding.disabled = False
+
+ mock_binding.created_at = Mock()
+ mock_binding.created_at.timestamp.return_value = 1640995200
+ mock_binding.updated_at = Mock()
+ mock_binding.updated_at.timestamp.return_value = 1640995200
+
+ return mock_binding
+
+ def test_integration_coverage_validation(self):
+ """Validate integration test coverage meets quality standards"""
+ core_scenarios = {
+ "business_logic": ["end_to_end_auth_flow", "cross_component_integration"],
+ "security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"],
+ "reliability": ["concurrent_creation_safety", "network_failure_recovery"],
+ "compatibility": ["all_providers_factory_creation"],
+ "boundaries": ["invalid_input_boundary", "http_error_handling"],
+ }
+
+ total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values())
+ assert total_scenarios >= 10
+
+ security_tests = core_scenarios["security"]
+ assert "multi_tenant_isolation" in security_tests
+ assert "sensitive_data_protection" in security_tests
diff --git a/api/tests/unit_tests/services/auth/test_auth_type.py b/api/tests/unit_tests/services/auth/test_auth_type.py
new file mode 100644
index 0000000000..94073f451e
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_auth_type.py
@@ -0,0 +1,150 @@
+import pytest
+
+from services.auth.auth_type import AuthType
+
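+# Naming quirk pinned below: AuthType.JINA serializes to "jinareader", not
+# "jina" (see the comparison cases).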
+
+class TestAuthType:
+ """Test cases for AuthType enum"""
+
+ def test_auth_type_is_str_enum(self):
+ """Test that AuthType is properly a StrEnum"""
+ assert issubclass(AuthType, str)
+ assert hasattr(AuthType, "__members__")
+
+ def test_auth_type_has_expected_values(self):
+ """Test that all expected auth types exist with correct values"""
+ expected_values = {
+ "FIRECRAWL": "firecrawl",
+ "WATERCRAWL": "watercrawl",
+ "JINA": "jinareader",
+ }
+
+ # Verify all expected members exist
+ for member_name, expected_value in expected_values.items():
+ assert hasattr(AuthType, member_name)
+ assert getattr(AuthType, member_name).value == expected_value
+
+ # Verify no extra members exist
+ assert len(AuthType) == len(expected_values)
+
+ @pytest.mark.parametrize(
+ ("auth_type", "expected_string"),
+ [
+ (AuthType.FIRECRAWL, "firecrawl"),
+ (AuthType.WATERCRAWL, "watercrawl"),
+ (AuthType.JINA, "jinareader"),
+ ],
+ )
+ def test_auth_type_string_representation(self, auth_type, expected_string):
+ """Test string representation of auth types"""
+ assert str(auth_type) == expected_string
+ assert auth_type.value == expected_string
+
+ @pytest.mark.parametrize(
+ ("auth_type", "compare_value", "expected_result"),
+ [
+ (AuthType.FIRECRAWL, "firecrawl", True),
+ (AuthType.WATERCRAWL, "watercrawl", True),
+ (AuthType.JINA, "jinareader", True),
+ (AuthType.FIRECRAWL, "FIRECRAWL", False), # Case sensitive
+ (AuthType.FIRECRAWL, "watercrawl", False),
+ (AuthType.JINA, "jina", False), # Full value mismatch
+ ],
+ )
+ def test_auth_type_comparison(self, auth_type, compare_value, expected_result):
+ """Test auth type comparison with strings"""
+ assert (auth_type == compare_value) is expected_result
+
+ def test_auth_type_iteration(self):
+ """Test that AuthType can be iterated over"""
+ auth_types = list(AuthType)
+ assert len(auth_types) == 3
+ assert AuthType.FIRECRAWL in auth_types
+ assert AuthType.WATERCRAWL in auth_types
+ assert AuthType.JINA in auth_types
+
+ def test_auth_type_membership(self):
+ """Test membership checking for AuthType"""
+ assert "firecrawl" in [auth.value for auth in AuthType]
+ assert "watercrawl" in [auth.value for auth in AuthType]
+ assert "jinareader" in [auth.value for auth in AuthType]
+ assert "invalid" not in [auth.value for auth in AuthType]
+
+ def test_auth_type_invalid_attribute_access(self):
+ """Test accessing non-existent auth type raises AttributeError"""
+ with pytest.raises(AttributeError):
+ _ = AuthType.INVALID_TYPE
+
+ def test_auth_type_immutability(self):
+ """Test that enum values cannot be modified"""
+ # Enum members are read-only; reassignment raises AttributeError
+ with pytest.raises(AttributeError):
+ AuthType.FIRECRAWL = "modified"
+
+ def test_auth_type_from_value(self):
+ """Test creating AuthType from string value"""
+ assert AuthType("firecrawl") == AuthType.FIRECRAWL
+ assert AuthType("watercrawl") == AuthType.WATERCRAWL
+ assert AuthType("jinareader") == AuthType.JINA
+
+ # Test invalid value
+ with pytest.raises(ValueError) as exc_info:
+ AuthType("invalid_auth_type")
+ assert "invalid_auth_type" in str(exc_info.value)
+
+ def test_auth_type_name_property(self):
+ """Test the name property of enum members"""
+ assert AuthType.FIRECRAWL.name == "FIRECRAWL"
+ assert AuthType.WATERCRAWL.name == "WATERCRAWL"
+ assert AuthType.JINA.name == "JINA"
+
+ @pytest.mark.parametrize(
+ "auth_type",
+ [AuthType.FIRECRAWL, AuthType.WATERCRAWL, AuthType.JINA],
+ )
+ def test_auth_type_isinstance_checks(self, auth_type):
+ """Test isinstance checks for auth types"""
+ assert isinstance(auth_type, AuthType)
+ assert isinstance(auth_type, str)
+ assert isinstance(auth_type.value, str)
+
+ def test_auth_type_hash(self):
+ """Test that auth types are hashable and can be used in sets/dicts"""
+ auth_set = {AuthType.FIRECRAWL, AuthType.WATERCRAWL, AuthType.JINA}
+ assert len(auth_set) == 3
+
+ auth_dict = {
+ AuthType.FIRECRAWL: "firecrawl_handler",
+ AuthType.WATERCRAWL: "watercrawl_handler",
+ AuthType.JINA: "jina_handler",
+ }
+ assert auth_dict[AuthType.FIRECRAWL] == "firecrawl_handler"
+
+ def test_auth_type_json_serializable(self):
+ """Test that auth types can be JSON serialized"""
+ import json
+
+ auth_data = {
+ "provider": AuthType.FIRECRAWL,
+ "enabled": True,
+ }
+
+ # Should serialize to string value
+ json_str = json.dumps(auth_data, default=str)
+ assert '"provider": "firecrawl"' in json_str
+
+ def test_auth_type_matches_factory_usage(self):
+ """Test that all AuthType values are handled by ApiKeyAuthFactory"""
+ # This test verifies that the enum values match what's expected
+ # by the factory implementation
+ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
+
+ for auth_type in AuthType:
+ # Should not raise ValueError for valid auth types
+ try:
+ auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(auth_type)
+ assert auth_class is not None
+ except ImportError:
+ # It's OK if the actual auth implementation doesn't exist
+ # We're just testing that the enum value is recognized
+ pass
diff --git a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py
new file mode 100644
index 0000000000..ffdf5897ed
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py
@@ -0,0 +1,191 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+
+from services.auth.firecrawl.firecrawl import FirecrawlAuth
+
+
+class TestFirecrawlAuth:
+ @pytest.fixture
+ def valid_credentials(self):
+ """Fixture for valid bearer credentials"""
+ return {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+
+ @pytest.fixture
+ def auth_instance(self, valid_credentials):
+ """Fixture for FirecrawlAuth instance with valid credentials"""
+ return FirecrawlAuth(valid_credentials)
+
+ def test_should_initialize_with_valid_bearer_credentials(self, valid_credentials):
+ """Test successful initialization with valid bearer credentials"""
+ auth = FirecrawlAuth(valid_credentials)
+ assert auth.api_key == "test_api_key_123"
+ assert auth.base_url == "https://api.firecrawl.dev"
+ assert auth.credentials == valid_credentials
+
+ def test_should_initialize_with_custom_base_url(self):
+ """Test initialization with custom base URL"""
+ credentials = {
+ "auth_type": "bearer",
+ "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
+ }
+ auth = FirecrawlAuth(credentials)
+ assert auth.api_key == "test_api_key_123"
+ assert auth.base_url == "https://custom.firecrawl.dev"
+
+ @pytest.mark.parametrize(
+ ("auth_type", "expected_error"),
+ [
+ ("basic", "Invalid auth type, Firecrawl auth type must be Bearer"),
+ ("x-api-key", "Invalid auth type, Firecrawl auth type must be Bearer"),
+ ("", "Invalid auth type, Firecrawl auth type must be Bearer"),
+ ],
+ )
+ def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
+ """Test that non-bearer auth types raise ValueError"""
+ credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
+ with pytest.raises(ValueError) as exc_info:
+ FirecrawlAuth(credentials)
+ assert str(exc_info.value) == expected_error
+
+ @pytest.mark.parametrize(
+ ("credentials", "expected_error"),
+ [
+ ({"auth_type": "bearer", "config": {}}, "No API key provided"),
+ ({"auth_type": "bearer"}, "No API key provided"),
+ ({"auth_type": "bearer", "config": {"api_key": ""}}, "No API key provided"),
+ ({"auth_type": "bearer", "config": {"api_key": None}}, "No API key provided"),
+ ],
+ )
+ def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
+ """Test that missing or empty API key raises ValueError"""
+ with pytest.raises(ValueError) as exc_info:
+ FirecrawlAuth(credentials)
+ assert str(exc_info.value) == expected_error
+
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
+ """Test successful credential validation"""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+
+ result = auth_instance.validate_credentials()
+
+ assert result is True
+ expected_data = {
+ "url": "https://example.com",
+ "includePaths": [],
+ "excludePaths": [],
+ "limit": 1,
+ "scrapeOptions": {"onlyMainContent": True},
+ }
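+        # The auth check is assumed to issue a minimal single-page crawl request
+        # to /v1/crawl in order to exercise the API key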
+ mock_post.assert_called_once_with(
+ "https://api.firecrawl.dev/v1/crawl",
+ headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
+ json=expected_data,
+ )
+
+ @pytest.mark.parametrize(
+ ("status_code", "error_message"),
+ [
+ (402, "Payment required"),
+ (409, "Conflict error"),
+ (500, "Internal server error"),
+ ],
+ )
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
+ """Test handling of various HTTP error codes"""
+ mock_response = MagicMock()
+ mock_response.status_code = status_code
+ mock_response.json.return_value = {"error": error_message}
+ mock_post.return_value = mock_response
+
+ with pytest.raises(Exception) as exc_info:
+ auth_instance.validate_credentials()
+ assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
+
+ @pytest.mark.parametrize(
+ ("status_code", "response_text", "has_json_error", "expected_error_contains"),
+ [
+ (403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
+ (404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
+ (401, "Not JSON", True, "Expecting value"), # JSON decode error
+ ],
+ )
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_should_handle_unexpected_errors(
+ self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
+ ):
+ """Test handling of unexpected errors with various response formats"""
+ mock_response = MagicMock()
+ mock_response.status_code = status_code
+ mock_response.text = response_text
+ if has_json_error:
+ mock_response.json.side_effect = Exception("Not JSON")
+ mock_post.return_value = mock_response
+
+ with pytest.raises(Exception) as exc_info:
+ auth_instance.validate_credentials()
+ assert expected_error_contains in str(exc_info.value)
+
+ @pytest.mark.parametrize(
+ ("exception_type", "exception_message"),
+ [
+ (requests.ConnectionError, "Network error"),
+ (requests.Timeout, "Request timeout"),
+ (requests.ReadTimeout, "Read timeout"),
+ (requests.ConnectTimeout, "Connection timeout"),
+ ],
+ )
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
+ """Test handling of various network-related errors including timeouts"""
+ mock_post.side_effect = exception_type(exception_message)
+
+ with pytest.raises(exception_type) as exc_info:
+ auth_instance.validate_credentials()
+ assert exception_message in str(exc_info.value)
+
+ def test_should_not_expose_api_key_in_error_messages(self):
+ """Test that API key is not exposed in error messages"""
+ credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
+ auth = FirecrawlAuth(credentials)
+
+ # Verify API key is stored but not in any error message
+ assert auth.api_key == "super_secret_key_12345"
+
+ # Test various error scenarios don't expose the key
+ with pytest.raises(ValueError) as exc_info:
+ FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
+ assert "super_secret_key_12345" not in str(exc_info.value)
+
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_should_use_custom_base_url_in_validation(self, mock_post):
+ """Test that custom base URL is used in validation"""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+
+ credentials = {
+ "auth_type": "bearer",
+ "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"},
+ }
+ auth = FirecrawlAuth(credentials)
+ result = auth.validate_credentials()
+
+ assert result is True
+ assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
+
+ @patch("services.auth.firecrawl.firecrawl.requests.post")
+ def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
+ """Test that timeout errors are handled gracefully with appropriate error message"""
+ mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds")
+
+ with pytest.raises(requests.Timeout) as exc_info:
+ auth_instance.validate_credentials()
+
+ # Verify the timeout exception is raised with original message
+ assert "timed out" in str(exc_info.value)
diff --git a/api/tests/unit_tests/services/auth/test_jina_auth.py b/api/tests/unit_tests/services/auth/test_jina_auth.py
new file mode 100644
index 0000000000..ccbca5a36f
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_jina_auth.py
@@ -0,0 +1,155 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+
+from services.auth.jina.jina import JinaAuth
+
+
+class TestJinaAuth:
+ def test_should_initialize_with_valid_bearer_credentials(self):
+ """Test successful initialization with valid bearer credentials"""
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+ assert auth.api_key == "test_api_key_123"
+ assert auth.credentials == credentials
+
+ def test_should_raise_error_for_invalid_auth_type(self):
+ """Test that non-bearer auth type raises ValueError"""
+ credentials = {"auth_type": "basic", "config": {"api_key": "test_api_key_123"}}
+ with pytest.raises(ValueError) as exc_info:
+ JinaAuth(credentials)
+ assert str(exc_info.value) == "Invalid auth type, Jina Reader auth type must be Bearer"
+
+ def test_should_raise_error_for_missing_api_key(self):
+ """Test that missing API key raises ValueError"""
+ credentials = {"auth_type": "bearer", "config": {}}
+ with pytest.raises(ValueError) as exc_info:
+ JinaAuth(credentials)
+ assert str(exc_info.value) == "No API key provided"
+
+ def test_should_raise_error_for_missing_config(self):
+ """Test that missing config section raises ValueError"""
+ credentials = {"auth_type": "bearer"}
+ with pytest.raises(ValueError) as exc_info:
+ JinaAuth(credentials)
+ assert str(exc_info.value) == "No API key provided"
+
+ @patch("services.auth.jina.jina.requests.post")
+ def test_should_validate_valid_credentials_successfully(self, mock_post):
+ """Test successful credential validation"""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_post.return_value = mock_response
+
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+ result = auth.validate_credentials()
+
+ assert result is True
+ mock_post.assert_called_once_with(
+ "https://r.jina.ai",
+ headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"},
+ json={"url": "https://example.com"},
+ )
+
+ @patch("services.auth.jina.jina.requests.post")
+ def test_should_handle_http_402_error(self, mock_post):
+ """Test handling of 402 Payment Required error"""
+ mock_response = MagicMock()
+ mock_response.status_code = 402
+ mock_response.json.return_value = {"error": "Payment required"}
+ mock_post.return_value = mock_response
+
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+
+ with pytest.raises(Exception) as exc_info:
+ auth.validate_credentials()
+ assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
+
+ @patch("services.auth.jina.jina.requests.post")
+ def test_should_handle_http_409_error(self, mock_post):
+ """Test handling of 409 Conflict error"""
+ mock_response = MagicMock()
+ mock_response.status_code = 409
+ mock_response.json.return_value = {"error": "Conflict error"}
+ mock_post.return_value = mock_response
+
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+
+ with pytest.raises(Exception) as exc_info:
+ auth.validate_credentials()
+ assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
+
+ @patch("services.auth.jina.jina.requests.post")
+ def test_should_handle_http_500_error(self, mock_post):
+ """Test handling of 500 Internal Server Error"""
+ mock_response = MagicMock()
+ mock_response.status_code = 500
+ mock_response.json.return_value = {"error": "Internal server error"}
+ mock_post.return_value = mock_response
+
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+
+ with pytest.raises(Exception) as exc_info:
+ auth.validate_credentials()
+ assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
+
+ @patch("services.auth.jina.jina.requests.post")
+ def test_should_handle_unexpected_error_with_text_response(self, mock_post):
+ """Test handling of unexpected errors with text response"""
+ mock_response = MagicMock()
+ mock_response.status_code = 403
+ mock_response.text = '{"error": "Forbidden"}'
+ mock_response.json.side_effect = Exception("Not JSON")
+ mock_post.return_value = mock_response
+
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+
+ with pytest.raises(Exception) as exc_info:
+ auth.validate_credentials()
+ assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
+
+ @patch("services.auth.jina.jina.requests.post")
+ def test_should_handle_unexpected_error_without_text(self, mock_post):
+ """Test handling of unexpected errors without text response"""
+ mock_response = MagicMock()
+ mock_response.status_code = 404
+ mock_response.text = ""
+ mock_response.json.side_effect = Exception("Not JSON")
+ mock_post.return_value = mock_response
+
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+
+ with pytest.raises(Exception) as exc_info:
+ auth.validate_credentials()
+ assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
+
+ @patch("services.auth.jina.jina.requests.post")
+ def test_should_handle_network_errors(self, mock_post):
+ """Test handling of network connection errors"""
+ mock_post.side_effect = requests.ConnectionError("Network error")
+
+ credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
+ auth = JinaAuth(credentials)
+
+ with pytest.raises(requests.ConnectionError):
+ auth.validate_credentials()
+
+ def test_should_not_expose_api_key_in_error_messages(self):
+ """Test that API key is not exposed in error messages"""
+ credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}
+ auth = JinaAuth(credentials)
+
+ # Verify API key is stored but not in any error message
+ assert auth.api_key == "super_secret_key_12345"
+
+ # Test various error scenarios don't expose the key
+ with pytest.raises(ValueError) as exc_info:
+ JinaAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
+ assert "super_secret_key_12345" not in str(exc_info.value)
diff --git a/api/tests/unit_tests/services/auth/test_watercrawl_auth.py b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py
new file mode 100644
index 0000000000..bacf0b24ea
--- /dev/null
+++ b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py
@@ -0,0 +1,205 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+
+from services.auth.watercrawl.watercrawl import WatercrawlAuth
+
+
+class TestWatercrawlAuth:
+ @pytest.fixture
+ def valid_credentials(self):
+ """Fixture for valid x-api-key credentials"""
+ return {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123"}}
+
+ @pytest.fixture
+ def auth_instance(self, valid_credentials):
+ """Fixture for WatercrawlAuth instance with valid credentials"""
+ return WatercrawlAuth(valid_credentials)
+
+ def test_should_initialize_with_valid_x_api_key_credentials(self, valid_credentials):
+ """Test successful initialization with valid x-api-key credentials"""
+ auth = WatercrawlAuth(valid_credentials)
+ assert auth.api_key == "test_api_key_123"
+ assert auth.base_url == "https://app.watercrawl.dev"
+ assert auth.credentials == valid_credentials
+
+ def test_should_initialize_with_custom_base_url(self):
+ """Test initialization with custom base URL"""
+ credentials = {
+ "auth_type": "x-api-key",
+ "config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
+ }
+ auth = WatercrawlAuth(credentials)
+ assert auth.api_key == "test_api_key_123"
+ assert auth.base_url == "https://custom.watercrawl.dev"
+
+ @pytest.mark.parametrize(
+ ("auth_type", "expected_error"),
+ [
+ ("bearer", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
+ ("basic", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
+ ("", "Invalid auth type, WaterCrawl auth type must be x-api-key"),
+ ],
+ )
+ def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error):
+ """Test that non-x-api-key auth types raise ValueError"""
+ credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}}
+ with pytest.raises(ValueError) as exc_info:
+ WatercrawlAuth(credentials)
+ assert str(exc_info.value) == expected_error
+
+ @pytest.mark.parametrize(
+ ("credentials", "expected_error"),
+ [
+ ({"auth_type": "x-api-key", "config": {}}, "No API key provided"),
+ ({"auth_type": "x-api-key"}, "No API key provided"),
+ ({"auth_type": "x-api-key", "config": {"api_key": ""}}, "No API key provided"),
+ ({"auth_type": "x-api-key", "config": {"api_key": None}}, "No API key provided"),
+ ],
+ )
+ def test_should_raise_error_for_missing_api_key(self, credentials, expected_error):
+ """Test that missing or empty API key raises ValueError"""
+ with pytest.raises(ValueError) as exc_info:
+ WatercrawlAuth(credentials)
+ assert str(exc_info.value) == expected_error
+
+ @patch("services.auth.watercrawl.watercrawl.requests.get")
+ def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
+ """Test successful credential validation"""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_get.return_value = mock_response
+
+ result = auth_instance.validate_credentials()
+
+ assert result is True
+ mock_get.assert_called_once_with(
+ "https://app.watercrawl.dev/api/v1/core/crawl-requests/",
+ headers={"Content-Type": "application/json", "X-API-KEY": "test_api_key_123"},
+ )
+
+ @pytest.mark.parametrize(
+ ("status_code", "error_message"),
+ [
+ (402, "Payment required"),
+ (409, "Conflict error"),
+ (500, "Internal server error"),
+ ],
+ )
+ @patch("services.auth.watercrawl.watercrawl.requests.get")
+ def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
+ """Test handling of various HTTP error codes"""
+ mock_response = MagicMock()
+ mock_response.status_code = status_code
+ mock_response.json.return_value = {"error": error_message}
+ mock_get.return_value = mock_response
+
+ with pytest.raises(Exception) as exc_info:
+ auth_instance.validate_credentials()
+ assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}"
+
+ @pytest.mark.parametrize(
+ ("status_code", "response_text", "has_json_error", "expected_error_contains"),
+ [
+ (403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"),
+ (404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"),
+ (401, "Not JSON", True, "Expecting value"), # JSON decode error
+ ],
+ )
+ @patch("services.auth.watercrawl.watercrawl.requests.get")
+ def test_should_handle_unexpected_errors(
+ self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
+ ):
+ """Test handling of unexpected errors with various response formats"""
+ mock_response = MagicMock()
+ mock_response.status_code = status_code
+ mock_response.text = response_text
+ if has_json_error:
+ mock_response.json.side_effect = Exception("Not JSON")
+ mock_get.return_value = mock_response
+
+ with pytest.raises(Exception) as exc_info:
+ auth_instance.validate_credentials()
+ assert expected_error_contains in str(exc_info.value)
+
+ @pytest.mark.parametrize(
+ ("exception_type", "exception_message"),
+ [
+ (requests.ConnectionError, "Network error"),
+ (requests.Timeout, "Request timeout"),
+ (requests.ReadTimeout, "Read timeout"),
+ (requests.ConnectTimeout, "Connection timeout"),
+ ],
+ )
+ @patch("services.auth.watercrawl.watercrawl.requests.get")
+ def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
+ """Test handling of various network-related errors including timeouts"""
+ mock_get.side_effect = exception_type(exception_message)
+
+ with pytest.raises(exception_type) as exc_info:
+ auth_instance.validate_credentials()
+ assert exception_message in str(exc_info.value)
+
+ def test_should_not_expose_api_key_in_error_messages(self):
+ """Test that API key is not exposed in error messages"""
+ credentials = {"auth_type": "x-api-key", "config": {"api_key": "super_secret_key_12345"}}
+ auth = WatercrawlAuth(credentials)
+
+ # Verify API key is stored but not in any error message
+ assert auth.api_key == "super_secret_key_12345"
+
+ # Test various error scenarios don't expose the key
+ with pytest.raises(ValueError) as exc_info:
+ WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
+ assert "super_secret_key_12345" not in str(exc_info.value)
+
+ @patch("services.auth.watercrawl.watercrawl.requests.get")
+ def test_should_use_custom_base_url_in_validation(self, mock_get):
+ """Test that custom base URL is used in validation"""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_get.return_value = mock_response
+
+ credentials = {
+ "auth_type": "x-api-key",
+ "config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"},
+ }
+ auth = WatercrawlAuth(credentials)
+ result = auth.validate_credentials()
+
+ assert result is True
+ assert mock_get.call_args[0][0] == "https://custom.watercrawl.dev/api/v1/core/crawl-requests/"
+
+ @pytest.mark.parametrize(
+ ("base_url", "expected_url"),
+ [
+ ("https://app.watercrawl.dev", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
+ ("https://app.watercrawl.dev/", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
+ ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
+ ],
+ )
+ @patch("services.auth.watercrawl.watercrawl.requests.get")
+ def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
+ """Test that urljoin is used correctly for URL construction with various base URLs"""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_get.return_value = mock_response
+
+ credentials = {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123", "base_url": base_url}}
+ auth = WatercrawlAuth(credentials)
+ auth.validate_credentials()
+
+ # Verify the correct URL was called
+ assert mock_get.call_args[0][0] == expected_url
+
+ @patch("services.auth.watercrawl.watercrawl.requests.get")
+    def test_should_propagate_timeout_errors(self, mock_get, auth_instance):
+        """Test that timeout errors propagate with their original message"""
+        mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds")
+
+        with pytest.raises(requests.Timeout) as exc_info:
+            auth_instance.validate_credentials()
+
+        # Verify the timeout exception is raised with its original message
+        assert "timed out" in str(exc_info.value)
diff --git a/api/tests/unit_tests/services/services_test_help.py b/api/tests/unit_tests/services/services_test_help.py
new file mode 100644
index 0000000000..c6b962f7fc
--- /dev/null
+++ b/api/tests/unit_tests/services/services_test_help.py
@@ -0,0 +1,59 @@
+from unittest.mock import MagicMock
+
+
+class ServiceDbTestHelper:
+ """
+ Helper class for service database query tests.
+ """
+
+ @staticmethod
+ def setup_db_query_filter_by_mock(mock_db, query_results):
+ """
+ Smart database query mock that responds based on model type and query parameters.
+
+ Args:
+ mock_db: Mock database session
+ query_results: Dict mapping (model_name, filter_key, filter_value) to return value
+ Example: {('Account', 'email', 'test@example.com'): mock_account}
+ """
+
+ def query_side_effect(model):
+ mock_query = MagicMock()
+
+ def filter_by_side_effect(**kwargs):
+ mock_filter_result = MagicMock()
+
+ def first_side_effect():
+ # Find matching result based on model and filter parameters
+ for (model_name, filter_key, filter_value), result in query_results.items():
+ if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
+ return result
+ return None
+
+ mock_filter_result.first.side_effect = first_side_effect
+
+ # Handle order_by calls for complex queries
+ def order_by_side_effect(*args, **kwargs):
+ mock_order_result = MagicMock()
+
+ def order_first_side_effect():
+ # Look for order_by results in the same query_results dict
+ for (model_name, filter_key, filter_value), result in query_results.items():
+ if (
+ model.__name__ == model_name
+ and filter_key == "order_by"
+ and filter_value == "first_available"
+ ):
+ return result
+ return None
+
+ mock_order_result.first.side_effect = order_first_side_effect
+ return mock_order_result
+
+ mock_filter_result.order_by.side_effect = order_by_side_effect
+ return mock_filter_result
+
+ mock_query.filter_by.side_effect = filter_by_side_effect
+ return mock_query
+
+ mock_db.session.query.side_effect = query_side_effect
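+
+
+# Example usage (illustrative sketch; `Account` stands in for any model class
+# and `mock_db` for a patched database module):
+#
+#     mock_account = MagicMock()
+#     ServiceDbTestHelper.setup_db_query_filter_by_mock(
+#         mock_db,
+#         {("Account", "email", "test@example.com"): mock_account},
+#     )
+#     assert (
+#         mock_db.session.query(Account).filter_by(email="test@example.com").first()
+#         is mock_account
+#     )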
diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py
new file mode 100644
index 0000000000..442839e44e
--- /dev/null
+++ b/api/tests/unit_tests/services/test_account_service.py
@@ -0,0 +1,1545 @@
+import json
+from datetime import datetime, timedelta
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from configs import dify_config
+from models.account import Account
+from services.account_service import AccountService, RegisterService, TenantService
+from services.errors.account import (
+ AccountAlreadyInTenantError,
+ AccountLoginError,
+ AccountNotFoundError,
+ AccountPasswordError,
+ AccountRegisterError,
+ CurrentPasswordIncorrectError,
+)
+from tests.unit_tests.services.services_test_help import ServiceDbTestHelper
+
+
+class TestAccountAssociatedDataFactory:
+ """Factory class for creating test data and mock objects for account service tests."""
+
+ @staticmethod
+ def create_account_mock(
+ account_id: str = "user-123",
+ email: str = "test@example.com",
+ name: str = "Test User",
+ status: str = "active",
+ password: str = "hashed_password",
+ password_salt: str = "salt",
+ interface_language: str = "en-US",
+ interface_theme: str = "light",
+ timezone: str = "UTC",
+ **kwargs,
+ ) -> MagicMock:
+ """Create a mock account with specified attributes."""
+ account = MagicMock(spec=Account)
+ account.id = account_id
+ account.email = email
+ account.name = name
+ account.status = status
+ account.password = password
+ account.password_salt = password_salt
+ account.interface_language = interface_language
+ account.interface_theme = interface_theme
+ account.timezone = timezone
+ # Set last_active_at to a datetime object that's older than 10 minutes
+ account.last_active_at = datetime.now() - timedelta(minutes=15)
+ account.initialized_at = None
+ for key, value in kwargs.items():
+ setattr(account, key, value)
+ return account
+
+ @staticmethod
+ def create_tenant_join_mock(
+ tenant_id: str = "tenant-456",
+ account_id: str = "user-123",
+ current: bool = True,
+ role: str = "normal",
+ **kwargs,
+ ) -> MagicMock:
+ """Create a mock tenant account join record."""
+ tenant_join = MagicMock()
+ tenant_join.tenant_id = tenant_id
+ tenant_join.account_id = account_id
+ tenant_join.current = current
+ tenant_join.role = role
+ for key, value in kwargs.items():
+ setattr(tenant_join, key, value)
+ return tenant_join
+
+ @staticmethod
+ def create_feature_service_mock(allow_register: bool = True):
+ """Create a mock feature service."""
+ mock_service = MagicMock()
+ mock_service.get_system_features.return_value.is_allow_register = allow_register
+ return mock_service
+
+ @staticmethod
+ def create_billing_service_mock(email_frozen: bool = False):
+ """Create a mock billing service."""
+ mock_service = MagicMock()
+ mock_service.is_email_in_freeze.return_value = email_frozen
+ return mock_service
+
+
+class TestAccountService:
+ """
+ Comprehensive unit tests for AccountService methods.
+
+ This test suite covers all account-related operations including:
+ - Authentication and login
+ - Account creation and registration
+ - Password management
+ - JWT token generation
+ - User loading and tenant management
+ - Error conditions and edge cases
+ """
+
+ @pytest.fixture
+ def mock_db_dependencies(self):
+ """Common mock setup for database dependencies."""
+ with patch("services.account_service.db") as mock_db:
+ mock_db.session.add = MagicMock()
+ mock_db.session.commit = MagicMock()
+ yield {
+ "db": mock_db,
+ }
+
+ @pytest.fixture
+ def mock_password_dependencies(self):
+ """Mock setup for password-related functions."""
+ with (
+ patch("services.account_service.compare_password") as mock_compare_password,
+ patch("services.account_service.hash_password") as mock_hash_password,
+ patch("services.account_service.valid_password") as mock_valid_password,
+ ):
+ yield {
+ "compare_password": mock_compare_password,
+ "hash_password": mock_hash_password,
+ "valid_password": mock_valid_password,
+ }
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.account_service.FeatureService") as mock_feature_service,
+ patch("services.account_service.BillingService") as mock_billing_service,
+ patch("services.account_service.PassportService") as mock_passport_service,
+ ):
+ yield {
+ "feature_service": mock_feature_service,
+ "billing_service": mock_billing_service,
+ "passport_service": mock_passport_service,
+ }
+
+ @pytest.fixture
+ def mock_db_with_autospec(self):
+ """
+ Mock database with autospec for more realistic behavior.
+ This approach preserves the actual method signatures and behavior.
+ """
+ with patch("services.account_service.db", autospec=True) as mock_db:
+ # Create a more realistic session mock
+ mock_session = MagicMock()
+ mock_db.session = mock_session
+
+ # Setup basic session methods
+ mock_session.add = MagicMock()
+ mock_session.commit = MagicMock()
+ mock_session.query = MagicMock()
+
+ yield mock_db
+
+ def _assert_database_operations_called(self, mock_db):
+ """Helper method to verify database operations were called."""
+ mock_db.session.commit.assert_called()
+
+ def _assert_database_operations_not_called(self, mock_db):
+ """Helper method to verify database operations were not called."""
+ mock_db.session.commit.assert_not_called()
+
+ def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs):
+ """Helper method to verify that specific exception is raised."""
+ with pytest.raises(exception_type):
+ callable_func(*args, **kwargs)
+
+ # ==================== Authentication Tests ====================
+
+ def test_authenticate_success(self, mock_db_dependencies, mock_password_dependencies):
+ """Test successful authentication with correct email and password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "test@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ mock_password_dependencies["compare_password"].return_value = True
+
+ # Execute test
+ result = AccountService.authenticate("test@example.com", "password")
+
+ # Verify results
+ assert result == mock_account
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_authenticate_account_not_found(self, mock_db_dependencies):
+ """Test authentication when account does not exist."""
+ # Setup smart database query mock - no matching results
+ query_results = {("Account", "email", "notfound@example.com"): None}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password"
+ )
+
+ def test_authenticate_account_banned(self, mock_db_dependencies):
+ """Test authentication when account is banned."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "banned@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test and verify exception
+ self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password")
+
+ def test_authenticate_password_error(self, mock_db_dependencies, mock_password_dependencies):
+ """Test authentication with wrong password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "test@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ mock_password_dependencies["compare_password"].return_value = False
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountPasswordError, AccountService.authenticate, "test@example.com", "wrongpassword"
+ )
+
+ def test_authenticate_pending_account_activates(self, mock_db_dependencies, mock_password_dependencies):
+ """Test authentication for a pending account, which should activate on login."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending")
+
+ # Setup smart database query mock
+ query_results = {("Account", "email", "pending@example.com"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ mock_password_dependencies["compare_password"].return_value = True
+
+ # Execute test
+ result = AccountService.authenticate("pending@example.com", "password")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_account.status == "active"
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ # ==================== Account Creation Tests ====================
+
+ def test_create_account_success(
+ self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies
+ ):
+ """Test successful account creation with all required parameters."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+ mock_password_dependencies["hash_password"].return_value = b"hashed_password"
+
+ # Execute test
+ result = AccountService.create_account(
+ email="test@example.com",
+ name="Test User",
+ interface_language="en-US",
+ password="password123",
+ interface_theme="light",
+ )
+
+ # Verify results
+ assert result.email == "test@example.com"
+ assert result.name == "Test User"
+ assert result.interface_language == "en-US"
+ assert result.interface_theme == "light"
+ assert result.password is not None
+ assert result.password_salt is not None
+ assert result.timezone is not None
+
+ # Verify database operations
+ mock_db_dependencies["db"].session.add.assert_called_once()
+ added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
+ assert added_account.email == "test@example.com"
+ assert added_account.name == "Test User"
+ assert added_account.interface_language == "en-US"
+ assert added_account.interface_theme == "light"
+ assert added_account.password is not None
+ assert added_account.password_salt is not None
+ assert added_account.timezone is not None
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_create_account_registration_disabled(self, mock_external_service_dependencies):
+ """Test account creation when registration is disabled."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = False
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+            Exception,  # AccountNotFound is raised when registration is disabled
+ AccountService.create_account,
+ email="test@example.com",
+ name="Test User",
+ interface_language="en-US",
+ )
+
+ def test_create_account_email_frozen(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account creation with frozen email address."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True
+        dify_config.BILLING_ENABLED = True
+
+        # Execute test and verify exception
+        try:
+            self._assert_exception_raised(
+                AccountRegisterError,
+                AccountService.create_account,
+                email="frozen@example.com",
+                name="Test User",
+                interface_language="en-US",
+            )
+        finally:
+            # Restore the flag even if the assertion fails so other tests stay isolated
+            dify_config.BILLING_ENABLED = False
+
+ def test_create_account_without_password(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account creation without password (for invite-based registration)."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Execute test
+ result = AccountService.create_account(
+ email="test@example.com",
+ name="Test User",
+ interface_language="zh-CN",
+ password=None,
+ interface_theme="dark",
+ )
+
+ # Verify results
+ assert result.email == "test@example.com"
+ assert result.name == "Test User"
+ assert result.interface_language == "zh-CN"
+ assert result.interface_theme == "dark"
+ assert result.password is None
+ assert result.password_salt is None
+ assert result.timezone is not None
+
+ # Verify database operations
+ mock_db_dependencies["db"].session.add.assert_called_once()
+ added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
+ assert added_account.email == "test@example.com"
+ assert added_account.name == "Test User"
+ assert added_account.interface_language == "zh-CN"
+ assert added_account.interface_theme == "dark"
+ assert added_account.password is None
+ assert added_account.password_salt is None
+ assert added_account.timezone is not None
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ # ==================== Password Management Tests ====================
+
+ def test_update_account_password_success(self, mock_db_dependencies, mock_password_dependencies):
+ """Test successful password update with correct current password and valid new password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_password_dependencies["compare_password"].return_value = True
+ mock_password_dependencies["valid_password"].return_value = None
+ mock_password_dependencies["hash_password"].return_value = b"new_hashed_password"
+
+ # Execute test
+ result = AccountService.update_account_password(mock_account, "old_password", "new_password123")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_account.password is not None
+ assert mock_account.password_salt is not None
+
+ # Verify password validation was called
+ mock_password_dependencies["compare_password"].assert_called_once_with(
+ "old_password", "hashed_password", "salt"
+ )
+ mock_password_dependencies["valid_password"].assert_called_once_with("new_password123")
+
+ # Verify database operations
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_update_account_password_current_password_incorrect(self, mock_password_dependencies):
+ """Test password update with incorrect current password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_password_dependencies["compare_password"].return_value = False
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ CurrentPasswordIncorrectError,
+ AccountService.update_account_password,
+ mock_account,
+ "wrong_password",
+ "new_password123",
+ )
+
+ # Verify password comparison was called
+ mock_password_dependencies["compare_password"].assert_called_once_with(
+ "wrong_password", "hashed_password", "salt"
+ )
+
+ def test_update_account_password_invalid_new_password(self, mock_password_dependencies):
+ """Test password update with invalid new password."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_password_dependencies["compare_password"].return_value = True
+ mock_password_dependencies["valid_password"].side_effect = ValueError("Password too short")
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ ValueError, AccountService.update_account_password, mock_account, "old_password", "short"
+ )
+
+ # Verify password validation was called
+ mock_password_dependencies["valid_password"].assert_called_once_with("short")
+
+ # ==================== User Loading Tests ====================
+
+ def test_load_user_success(self, mock_db_dependencies):
+ """Test successful user loading with current tenant."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock()
+
+ # Setup smart database query mock
+ query_results = {
+ ("Account", "id", "user-123"): mock_account,
+ ("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join,
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock datetime
+ with patch("services.account_service.datetime") as mock_datetime:
+ mock_now = datetime.now()
+ mock_datetime.now.return_value = mock_now
+ mock_datetime.UTC = "UTC"
+
+ # Execute test
+ result = AccountService.load_user("user-123")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_account.set_tenant_id.called
+
+ def test_load_user_not_found(self, mock_db_dependencies):
+ """Test user loading when user does not exist."""
+ # Setup smart database query mock - no matching results
+ query_results = {("Account", "id", "non-existent-user"): None}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test
+ result = AccountService.load_user("non-existent-user")
+
+ # Verify results
+ assert result is None
+
+ def test_load_user_banned(self, mock_db_dependencies):
+ """Test user loading when user is banned."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned")
+
+ # Setup smart database query mock
+ query_results = {("Account", "id", "user-123"): mock_account}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ Exception, # Unauthorized
+ AccountService.load_user,
+ "user-123",
+ )
+
+ def test_load_user_no_current_tenant(self, mock_db_dependencies):
+ """Test user loading when user has no current tenant but has available tenants."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False)
+
+ # Setup smart database query mock for complex scenario
+ query_results = {
+ ("Account", "id", "user-123"): mock_account,
+ ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
+ ("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock datetime
+ with patch("services.account_service.datetime") as mock_datetime:
+ mock_now = datetime.now()
+ mock_datetime.now.return_value = mock_now
+ mock_datetime.UTC = "UTC"
+
+ # Execute test
+ result = AccountService.load_user("user-123")
+
+ # Verify results
+ assert result == mock_account
+ assert mock_available_tenant.current is True
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_load_user_no_tenants(self, mock_db_dependencies):
+ """Test user loading when user has no tenants at all."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock for no tenants scenario
+ query_results = {
+ ("Account", "id", "user-123"): mock_account,
+ ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant
+ ("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock datetime
+ with patch("services.account_service.datetime") as mock_datetime:
+ mock_now = datetime.now()
+ mock_datetime.now.return_value = mock_now
+ mock_datetime.UTC = "UTC"
+
+ # Execute test
+ result = AccountService.load_user("user-123")
+
+ # Verify results
+ assert result is None
+
+
+class TestTenantService:
+ """
+ Comprehensive unit tests for TenantService methods.
+
+ This test suite covers all tenant-related operations including:
+ - Tenant creation and management
+ - Member management and permissions
+ - Tenant switching
+ - Role updates and permission checks
+ - Error conditions and edge cases
+ """
+
+ @pytest.fixture
+ def mock_db_dependencies(self):
+ """Common mock setup for database dependencies."""
+ with patch("services.account_service.db") as mock_db:
+ mock_db.session.add = MagicMock()
+ mock_db.session.commit = MagicMock()
+ yield {
+ "db": mock_db,
+ }
+
+ @pytest.fixture
+ def mock_rsa_dependencies(self):
+ """Mock setup for RSA-related functions."""
+ with patch("services.account_service.generate_key_pair") as mock_generate_key_pair:
+ yield mock_generate_key_pair
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.account_service.FeatureService") as mock_feature_service,
+ patch("services.account_service.BillingService") as mock_billing_service,
+ ):
+ yield {
+ "feature_service": mock_feature_service,
+ "billing_service": mock_billing_service,
+ }
+
+ def _assert_database_operations_called(self, mock_db):
+ """Helper method to verify database operations were called."""
+ mock_db.session.commit.assert_called()
+
+ def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs):
+ """Helper method to verify that specific exception is raised."""
+ with pytest.raises(exception_type):
+ callable_func(*args, **kwargs)
+
+ # ==================== Tenant Creation Tests ====================
+
+ def test_create_owner_tenant_if_not_exist_new_user(
+ self, mock_db_dependencies, mock_rsa_dependencies, mock_external_service_dependencies
+ ):
+ """Test creating owner tenant for new user without existing tenants."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock - no existing tenant joins
+ query_results = {
+ ("TenantAccountJoin", "account_id", "user-123"): None,
+ ("TenantAccountJoin", "tenant_id", "tenant-456"): None, # For has_roles check
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Setup external service mocks
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+
+ # Mock tenant creation
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.name = "Test User's Workspace"
+
+ # Mock database operations
+ mock_db_dependencies["db"].session.add = MagicMock()
+
+ # Mock RSA key generation
+ mock_rsa_dependencies.return_value = "mock_public_key"
+
+ # Mock has_roles method to return False (no existing owner)
+ with patch("services.account_service.TenantService.has_roles") as mock_has_roles:
+ mock_has_roles.return_value = False
+
+ # Mock Tenant creation to set proper ID
+ with patch("services.account_service.Tenant") as mock_tenant_class:
+ mock_tenant_instance = MagicMock()
+ mock_tenant_instance.id = "tenant-456"
+ mock_tenant_instance.name = "Test User's Workspace"
+ mock_tenant_class.return_value = mock_tenant_instance
+
+ # Execute test
+ TenantService.create_owner_tenant_if_not_exist(mock_account)
+
+ # Verify tenant was created with correct parameters
+ mock_db_dependencies["db"].session.add.assert_called()
+
+ # Get all calls to session.add
+ add_calls = mock_db_dependencies["db"].session.add.call_args_list
+
+ # Should have at least 2 calls: one for Tenant, one for TenantAccountJoin
+ assert len(add_calls) >= 2
+
+ # Verify Tenant was added with correct name
+ tenant_added = False
+ tenant_account_join_added = False
+
+ for call in add_calls:
+ added_object = call[0][0] # First argument of the call
+
+ # Check if it's a Tenant object
+ if hasattr(added_object, "name") and hasattr(added_object, "id"):
+ # This should be a Tenant object
+ assert added_object.name == "Test User's Workspace"
+ tenant_added = True
+
+ # Check if it's a TenantAccountJoin object
+ elif (
+ hasattr(added_object, "tenant_id")
+ and hasattr(added_object, "account_id")
+ and hasattr(added_object, "role")
+ ):
+ # This should be a TenantAccountJoin object
+ assert added_object.tenant_id is not None
+ assert added_object.account_id == "user-123"
+ assert added_object.role == "owner"
+ tenant_account_join_added = True
+
+ assert tenant_added, "Tenant object was not added to database"
+ assert tenant_account_join_added, "TenantAccountJoin object was not added to database"
+
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+ assert mock_rsa_dependencies.called, "RSA key generation was not called"
+
+ # ==================== Member Management Tests ====================
+
+ def test_create_tenant_member_success(self, mock_db_dependencies):
+ """Test successful tenant member creation."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Setup smart database query mock - no existing member
+ query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): None}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock database operations
+ mock_db_dependencies["db"].session.add = MagicMock()
+
+ # Execute test
+ result = TenantService.create_tenant_member(mock_tenant, mock_account, "normal")
+
+ # Verify member was created with correct parameters
+ assert result is not None
+ mock_db_dependencies["db"].session.add.assert_called_once()
+
+ # Verify the TenantAccountJoin object was added with correct parameters
+ added_tenant_account_join = mock_db_dependencies["db"].session.add.call_args[0][0]
+ assert added_tenant_account_join.tenant_id == "tenant-456"
+ assert added_tenant_account_join.account_id == "user-123"
+ assert added_tenant_account_join.role == "normal"
+
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ # ==================== Tenant Switching Tests ====================
+
+ def test_switch_tenant_success(self):
+ """Test successful tenant switching."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="user-123", current=False
+ )
+
+ # Mock the complex query in switch_tenant method
+ with patch("services.account_service.db") as mock_db:
+ # Mock the join query that returns the tenant_account_join
+ mock_query = MagicMock()
+ mock_where = MagicMock()
+ mock_where.first.return_value = mock_tenant_join
+ mock_query.where.return_value = mock_where
+ mock_query.join.return_value = mock_query
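+            # Returning mock_query from join() keeps the chain fluent, so
+            # query(...).join(...).where(...).first() resolves to mock_tenant_join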
+ mock_db.session.query.return_value = mock_query
+
+ # Execute test
+ TenantService.switch_tenant(mock_account, "tenant-456")
+
+ # Verify tenant was switched
+ assert mock_tenant_join.current is True
+ self._assert_database_operations_called(mock_db)
+
+ def test_switch_tenant_no_tenant_id(self):
+ """Test tenant switching without providing tenant ID."""
+ # Setup test data
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+
+ # Execute test and verify exception
+ self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None)
+
+ # ==================== Role Management Tests ====================
+
+ def test_update_member_role_success(self):
+ """Test successful member role update."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_member = TestAccountAssociatedDataFactory.create_account_mock(account_id="member-789")
+ mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123")
+ mock_target_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="member-789", role="normal"
+ )
+ mock_operator_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="operator-123", role="owner"
+ )
+
+ # Mock the database queries in update_member_role method
+ with patch("services.account_service.db") as mock_db:
+ # Mock the first query for operator permission check
+ mock_query1 = MagicMock()
+ mock_filter1 = MagicMock()
+ mock_filter1.first.return_value = mock_operator_join
+ mock_query1.filter_by.return_value = mock_filter1
+
+ # Mock the second query for target member
+ mock_query2 = MagicMock()
+ mock_filter2 = MagicMock()
+ mock_filter2.first.return_value = mock_target_join
+ mock_query2.filter_by.return_value = mock_filter2
+
+ # Make the query method return different mocks for different calls
+ mock_db.session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator)
+
+ # Verify role was updated
+ assert mock_target_join.role == "admin"
+ self._assert_database_operations_called(mock_db)
+
+ # ==================== Permission Check Tests ====================
+
+ def test_check_member_permission_success(self, mock_db_dependencies):
+ """Test successful member permission check."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123")
+ mock_member = TestAccountAssociatedDataFactory.create_account_mock(account_id="member-789")
+ mock_operator_join = TestAccountAssociatedDataFactory.create_tenant_join_mock(
+ tenant_id="tenant-456", account_id="operator-123", role="owner"
+ )
+
+ # Setup smart database query mock
+ query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): mock_operator_join}
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Execute test - should not raise exception
+ TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add")
+
+ def test_check_member_permission_operate_self(self):
+ """Test member permission check when operator tries to operate self."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123")
+
+ # Execute test and verify exception
+ from services.errors.account import CannotOperateSelfError
+
+ self._assert_exception_raised(
+ CannotOperateSelfError,
+ TenantService.check_member_permission,
+ mock_tenant,
+ mock_operator,
+ mock_operator, # Same as operator
+ "add",
+ )
+
+
+class TestRegisterService:
+ """
+ Comprehensive unit tests for RegisterService methods.
+
+ This test suite covers all registration-related operations including:
+ - System setup
+ - Account registration
+ - Member invitation
+ - Token management
+ - Invitation validation
+ - Error conditions and edge cases
+ """
+
+ @pytest.fixture
+ def mock_db_dependencies(self):
+ """Common mock setup for database dependencies."""
+ with patch("services.account_service.db") as mock_db:
+ mock_db.session.add = MagicMock()
+ mock_db.session.commit = MagicMock()
+ mock_db.session.begin_nested = MagicMock()
+ mock_db.session.rollback = MagicMock()
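+            # RegisterService.setup is assumed to open a nested transaction and
+            # roll back on failure, so both hooks must exist on the mock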
+ yield {
+ "db": mock_db,
+ }
+
+ @pytest.fixture
+ def mock_redis_dependencies(self):
+ """Mock setup for Redis-related functions."""
+ with patch("services.account_service.redis_client") as mock_redis:
+ yield mock_redis
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.account_service.FeatureService") as mock_feature_service,
+ patch("services.account_service.BillingService") as mock_billing_service,
+ patch("services.account_service.PassportService") as mock_passport_service,
+ ):
+ yield {
+ "feature_service": mock_feature_service,
+ "billing_service": mock_billing_service,
+ "passport_service": mock_passport_service,
+ }
+
+ @pytest.fixture
+ def mock_task_dependencies(self):
+ """Mock setup for task dependencies."""
+ with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail:
+ yield mock_send_mail
+
+ def _assert_database_operations_called(self, mock_db):
+ """Helper method to verify database operations were called."""
+ mock_db.session.commit.assert_called()
+
+ def _assert_database_operations_not_called(self, mock_db):
+ """Helper method to verify database operations were not called."""
+ mock_db.session.commit.assert_not_called()
+
+ def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs):
+ """Helper method to verify that specific exception is raised."""
+ with pytest.raises(exception_type):
+ callable_func(*args, **kwargs)
+
+ # ==================== Setup Tests ====================
+
+ def test_setup_success(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test successful system setup."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService.create_owner_tenant_if_not_exist
+ with patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_tenant:
+ # Mock DifySetup
+ with patch("services.account_service.DifySetup") as mock_dify_setup:
+ mock_dify_setup_instance = MagicMock()
+ mock_dify_setup.return_value = mock_dify_setup_instance
+
+ # Execute test
+ RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1")
+
+ # Verify results
+ mock_create_account.assert_called_once_with(
+ email="admin@example.com",
+ name="Admin User",
+ interface_language="en-US",
+ password="password123",
+ is_setup=True,
+ )
+ mock_create_tenant.assert_called_once_with(account=mock_account, is_setup=True)
+ mock_dify_setup.assert_called_once()
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_setup_failure_rollback(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test setup failure with proper rollback."""
+ # Setup mocks to simulate failure
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account to raise exception
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.side_effect = Exception("Database error")
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ ValueError,
+ RegisterService.setup,
+ "admin@example.com",
+ "Admin User",
+ "password123",
+ "192.168.1.1",
+ )
+
+            # Verify the rollback path issued cleanup queries
+ mock_db_dependencies["db"].session.query.assert_called()
+
+ # ==================== Registration Tests ====================
+
+ def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test successful account registration."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService.create_tenant and create_tenant_member
+ with (
+ patch("services.account_service.TenantService.create_tenant") as mock_create_tenant,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.tenant_was_created") as mock_event,
+ ):
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_create_tenant.return_value = mock_tenant
+
+ # Execute test
+ result = RegisterService.register(
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ )
+
+ # Verify results
+ assert result == mock_account
+ assert result.status == "active"
+ assert result.initialized_at is not None
+ mock_create_account.assert_called_once_with(
+ email="test@example.com",
+ name="Test User",
+ interface_language="en-US",
+ password="password123",
+ is_setup=False,
+ )
+ mock_create_tenant.assert_called_once_with("Test User's Workspace")
+ mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
+ mock_event.send.assert_called_once_with(mock_tenant)
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account registration with OAuth integration."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account and link_account_integrate
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with (
+ patch("services.account_service.AccountService.create_account") as mock_create_account,
+ patch("services.account_service.AccountService.link_account_integrate") as mock_link_account,
+ ):
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.create_tenant") as mock_create_tenant,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.tenant_was_created") as mock_event,
+ ):
+ mock_tenant = MagicMock()
+ mock_create_tenant.return_value = mock_tenant
+
+ # Execute test
+ result = RegisterService.register(
+ email="test@example.com",
+ name="Test User",
+ password=None,
+ open_id="oauth123",
+ provider="google",
+ language="en-US",
+ )
+
+ # Verify results
+ assert result == mock_account
+ mock_link_account.assert_called_once_with("google", "oauth123", mock_account)
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_register_with_pending_status(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test account registration with pending status."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.create_tenant") as mock_create_tenant,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.tenant_was_created") as mock_event,
+ ):
+ mock_tenant = MagicMock()
+ mock_create_tenant.return_value = mock_tenant
+
+ # Execute test with pending status
+ from models.account import AccountStatus
+
+ result = RegisterService.register(
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ status=AccountStatus.PENDING,
+ )
+
+ # Verify results
+ assert result == mock_account
+ assert result.status == "pending"
+ self._assert_database_operations_called(mock_db_dependencies["db"])
+
+ def test_register_workspace_not_allowed(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test registration when workspace creation is not allowed."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.is_allow_create_workspace = True
+ mock_external_service_dependencies[
+ "feature_service"
+ ].get_system_features.return_value.license.workspaces.is_available.return_value = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock()
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.return_value = mock_account
+
+ # Execute test and verify exception
+ from services.errors.workspace import WorkSpaceNotAllowedCreateError
+
+ with patch("services.account_service.TenantService.create_tenant") as mock_create_tenant:
+ mock_create_tenant.side_effect = WorkSpaceNotAllowedCreateError()
+
+ self._assert_exception_raised(
+ AccountRegisterError,
+ RegisterService.register,
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ )
+
+ # Verify rollback was called
+ mock_db_dependencies["db"].session.rollback.assert_called()
+
+ def test_register_general_exception(self, mock_db_dependencies, mock_external_service_dependencies):
+ """Test registration with general exception handling."""
+ # Setup mocks
+ mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
+ mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
+
+ # Mock AccountService.create_account to raise exception
+ with patch("services.account_service.AccountService.create_account") as mock_create_account:
+ mock_create_account.side_effect = Exception("Unexpected error")
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountRegisterError,
+ RegisterService.register,
+ email="test@example.com",
+ name="Test User",
+ password="password123",
+ language="en-US",
+ )
+
+ # Verify rollback was called
+ mock_db_dependencies["db"].session.rollback.assert_called()
+
+ # ==================== Member Invitation Tests ====================
+
+ def test_invite_new_member_new_account(self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies):
+ """Test inviting a new member who doesn't have an account."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.name = "Test Workspace"
+ mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
+
+ # Mock database queries - need to mock the Session query
+ mock_session = MagicMock()
+ mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
+
+ with patch("services.account_service.Session") as mock_session_class:
+ mock_session_class.return_value.__enter__.return_value = mock_session
+ mock_session_class.return_value.__exit__.return_value = None
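+            # __exit__ returning None (falsy) keeps exception propagation intact inside the with-block.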
+
+ # Mock RegisterService.register
+ mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="new-user-456", email="newuser@example.com", name="newuser", status="pending"
+ )
+ with patch("services.account_service.RegisterService.register") as mock_register:
+ mock_register.return_value = mock_new_account
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant,
+ patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
+ ):
+ mock_generate_token.return_value = "invite-token-123"
+
+ # Execute test
+ result = RegisterService.invite_new_member(
+ tenant=mock_tenant,
+ email="newuser@example.com",
+ language="en-US",
+ role="normal",
+ inviter=mock_inviter,
+ )
+
+ # Verify results
+ assert result == "invite-token-123"
+ mock_register.assert_called_once_with(
+ email="newuser@example.com",
+ name="newuser",
+ language="en-US",
+ status="pending",
+ is_setup=True,
+ )
+ mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal")
+ mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id)
+ mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account)
+ mock_task_dependencies.delay.assert_called_once()
+
+ def test_invite_new_member_existing_account(
+ self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies
+ ):
+ """Test inviting a new member who already has an account."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.name = "Test Workspace"
+ mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
+ mock_existing_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="existing-user-456", email="existing@example.com", status="pending"
+ )
+
+ # Mock database queries - need to mock the Session query
+ mock_session = MagicMock()
+ mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
+
+ with patch("services.account_service.Session") as mock_session_class:
+ mock_session_class.return_value.__enter__.return_value = mock_session
+ mock_session_class.return_value.__exit__.return_value = None
+
+ # Mock the db.session.query for TenantAccountJoin
+ mock_db_query = MagicMock()
+ mock_db_query.filter_by.return_value.first.return_value = None # No existing member
+ mock_db_dependencies["db"].session.query.return_value = mock_db_query
+
+ # Mock TenantService methods
+ with (
+ patch("services.account_service.TenantService.check_member_permission") as mock_check_permission,
+ patch("services.account_service.TenantService.create_tenant_member") as mock_create_member,
+ patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token,
+ ):
+ mock_generate_token.return_value = "invite-token-123"
+
+ # Execute test
+ result = RegisterService.invite_new_member(
+ tenant=mock_tenant,
+ email="existing@example.com",
+ language="en-US",
+ role="normal",
+ inviter=mock_inviter,
+ )
+
+ # Verify results
+ assert result == "invite-token-123"
+ mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal")
+ mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account)
+ mock_task_dependencies.delay.assert_called_once()
+
+ def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test inviting a member who is already in the tenant."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
+ mock_existing_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="existing-user-456", email="existing@example.com", status="active"
+ )
+
+ # Mock database queries
+ query_results = {
+ ("Account", "email", "existing@example.com"): mock_existing_account,
+ (
+ "TenantAccountJoin",
+ "tenant_id",
+ "tenant-456",
+ ): TestAccountAssociatedDataFactory.create_tenant_join_mock(),
+ }
+ ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results)
+
+ # Mock TenantService methods
+ with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission:
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ AccountAlreadyInTenantError,
+ RegisterService.invite_new_member,
+ tenant=mock_tenant,
+ email="existing@example.com",
+ language="en-US",
+ role="normal",
+ inviter=mock_inviter,
+ )
+
+ def test_invite_new_member_no_inviter(self):
+ """Test inviting a member without providing an inviter."""
+ # Setup test data
+ mock_tenant = MagicMock()
+
+ # Execute test and verify exception
+ self._assert_exception_raised(
+ ValueError,
+ RegisterService.invite_new_member,
+ tenant=mock_tenant,
+ email="test@example.com",
+ language="en-US",
+ role="normal",
+ inviter=None,
+ )
+
+ # ==================== Token Management Tests ====================
+
+ def test_generate_invite_token_success(self, mock_redis_dependencies):
+ """Test successful invite token generation."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="user-123", email="test@example.com"
+ )
+
+ # Mock uuid generation
+ with patch("services.account_service.uuid.uuid4") as mock_uuid:
+ mock_uuid.return_value = "test-uuid-123"
+
+ # Execute test
+ result = RegisterService.generate_invite_token(mock_tenant, mock_account)
+
+ # Verify results
+ assert result == "test-uuid-123"
+ mock_redis_dependencies.setex.assert_called_once()
+
+ # Verify the stored data
+ call_args = mock_redis_dependencies.setex.call_args
+ assert call_args[0][0] == "member_invite:token:test-uuid-123"
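+            # redis-py setex takes positional args (name, time, value), so index 2 is the stored JSON payload.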
+ stored_data = json.loads(call_args[0][2])
+ assert stored_data["account_id"] == "user-123"
+ assert stored_data["email"] == "test@example.com"
+ assert stored_data["workspace_id"] == "tenant-456"
+
+ def test_is_valid_invite_token_valid(self, mock_redis_dependencies):
+ """Test checking valid invite token."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = b'{"test": "data"}'
+
+ # Execute test
+ result = RegisterService.is_valid_invite_token("valid-token")
+
+ # Verify results
+ assert result is True
+ mock_redis_dependencies.get.assert_called_once_with("member_invite:token:valid-token")
+
+ def test_is_valid_invite_token_invalid(self, mock_redis_dependencies):
+ """Test checking invalid invite token."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = None
+
+ # Execute test
+ result = RegisterService.is_valid_invite_token("invalid-token")
+
+ # Verify results
+ assert result is False
+ mock_redis_dependencies.get.assert_called_once_with("member_invite:token:invalid-token")
+
+ def test_revoke_token_with_workspace_and_email(self, mock_redis_dependencies):
+ """Test revoking token with workspace ID and email."""
+ # Execute test
+ RegisterService.revoke_token("workspace-123", "test@example.com", "token-123")
+
+ # Verify results
+ mock_redis_dependencies.delete.assert_called_once()
+ call_args = mock_redis_dependencies.delete.call_args
+ assert "workspace-123" in call_args[0][0]
+        # The email is hashed into the key, so assert on the key prefix rather than the raw address
+ assert "member_invite_token:" in call_args[0][0]
+
+ def test_revoke_token_without_workspace_and_email(self, mock_redis_dependencies):
+ """Test revoking token without workspace ID and email."""
+ # Execute test
+ RegisterService.revoke_token("", "", "token-123")
+
+ # Verify results
+ mock_redis_dependencies.delete.assert_called_once_with("member_invite:token:token-123")
+
+ # ==================== Invitation Validation Tests ====================
+
+ def test_get_invitation_if_token_valid_success(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test successful invitation validation."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.status = "normal"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="user-123", email="test@example.com"
+ )
+
+ with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token:
+ # Mock the invitation data returned by _get_invitation_by_token
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_get_invitation_by_token.return_value = invitation_data
+
+ # Mock database queries - complex query mocking
+ mock_query1 = MagicMock()
+ mock_query1.where.return_value.first.return_value = mock_tenant
+
+ mock_query2 = MagicMock()
+ mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
+
+ mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is not None
+ assert result["account"] == mock_account
+ assert result["tenant"] == mock_tenant
+ assert result["data"] == invitation_data
+
+ def test_get_invitation_if_token_valid_no_token_data(self, mock_redis_dependencies):
+ """Test invitation validation with no token data."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = None
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ def test_get_invitation_if_token_valid_tenant_not_found(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test invitation validation when tenant is not found."""
+ # Setup mock Redis data
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Mock database queries - no tenant found
+ mock_query = MagicMock()
+        mock_query.where.return_value.first.return_value = None
+ mock_db_dependencies["db"].session.query.return_value = mock_query
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ def test_get_invitation_if_token_valid_account_not_found(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test invitation validation when account is not found."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.status = "normal"
+
+ # Mock Redis data
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Mock database queries
+ mock_query1 = MagicMock()
+        mock_query1.where.return_value.first.return_value = mock_tenant
+
+ mock_query2 = MagicMock()
+ mock_query2.join.return_value.where.return_value.first.return_value = None # No account found
+
+ mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ def test_get_invitation_if_token_valid_account_id_mismatch(self, mock_db_dependencies, mock_redis_dependencies):
+ """Test invitation validation when account ID doesn't match."""
+ # Setup test data
+ mock_tenant = MagicMock()
+ mock_tenant.id = "tenant-456"
+ mock_tenant.status = "normal"
+ mock_account = TestAccountAssociatedDataFactory.create_account_mock(
+ account_id="different-user-456", email="test@example.com"
+ )
+
+ # Mock Redis data with different account ID
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Mock database queries
+ mock_query1 = MagicMock()
+        mock_query1.where.return_value.first.return_value = mock_tenant
+
+ mock_query2 = MagicMock()
+ mock_query2.join.return_value.where.return_value.first.return_value = (mock_account, "normal")
+
+ mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2]
+
+ # Execute test
+ result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123")
+
+ # Verify results
+ assert result is None
+
+ # ==================== Helper Method Tests ====================
+
+ def test_get_invitation_token_key(self):
+ """Test the _get_invitation_token_key helper method."""
+ # Execute test
+ result = RegisterService._get_invitation_token_key("test-token")
+
+ # Verify results
+ assert result == "member_invite:token:test-token"
+
+ def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies):
+ """Test _get_invitation_by_token with workspace ID and email."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = b"user-123"
+
+ # Execute test
+ result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com")
+
+ # Verify results
+ assert result is not None
+ assert result["account_id"] == "user-123"
+ assert result["email"] == "test@example.com"
+ assert result["workspace_id"] == "workspace-456"
+
+ def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies):
+ """Test _get_invitation_by_token without workspace ID and email."""
+ # Setup mock
+ invitation_data = {
+ "account_id": "user-123",
+ "email": "test@example.com",
+ "workspace_id": "tenant-456",
+ }
+ mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode()
+
+ # Execute test
+ result = RegisterService._get_invitation_by_token("token-123")
+
+ # Verify results
+ assert result is not None
+ assert result == invitation_data
+
+ def test_get_invitation_by_token_no_data(self, mock_redis_dependencies):
+ """Test _get_invitation_by_token with no data."""
+ # Setup mock
+ mock_redis_dependencies.get.return_value = None
+
+ # Execute test
+ result = RegisterService._get_invitation_by_token("token-123")
+
+ # Verify results
+ assert result is None
diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
index cdbb439c85..7c40b1e556 100644
--- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
+++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py
@@ -10,7 +10,6 @@ from core.model_runtime.entities.model_entities import ModelType
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
-from tests.unit_tests.conftest import redis_mock
class DatasetUpdateTestDataFactory:
@@ -103,17 +102,16 @@ class TestDatasetServiceUpdateDataset:
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
patch("extensions.ext_database.db.session") as mock_db,
- patch("services.dataset_service.datetime") as mock_datetime,
+ patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
- mock_datetime.datetime.now.return_value = current_time
- mock_datetime.UTC = datetime.UTC
+ mock_naive_utc_now.return_value = current_time
yield {
"get_dataset": mock_get_dataset,
"check_permission": mock_check_perm,
"db_session": mock_db,
- "datetime": mock_datetime,
+ "naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
}
@@ -293,7 +291,7 @@ class TestDatasetServiceUpdateDataset:
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"updated_by": user.id,
- "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+ "updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@@ -328,7 +326,7 @@ class TestDatasetServiceUpdateDataset:
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"updated_by": user.id,
- "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+ "updated_at": mock_dataset_service_dependencies["current_time"],
}
actual_call_args = mock_dataset_service_dependencies[
@@ -366,7 +364,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": None,
"retrieval_model": "new_model",
"updated_by": user.id,
- "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+ "updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@@ -423,7 +421,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-456",
"retrieval_model": "new_model",
"updated_by": user.id,
- "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+ "updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@@ -464,7 +462,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
- "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+ "updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@@ -526,7 +524,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-789",
"retrieval_model": "new_model",
"updated_by": user.id,
- "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+ "updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
@@ -569,7 +567,7 @@ class TestDatasetServiceUpdateDataset:
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
- "updated_at": mock_dataset_service_dependencies["current_time"].replace(tzinfo=None),
+ "updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
diff --git a/api/tests/unit_tests/services/tools/__init__.py b/api/tests/unit_tests/services/tools/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/services/tools/test_tools_transform_service.py b/api/tests/unit_tests/services/tools/test_tools_transform_service.py
new file mode 100644
index 0000000000..549ad018e8
--- /dev/null
+++ b/api/tests/unit_tests/services/tools/test_tools_transform_service.py
@@ -0,0 +1,301 @@
+from unittest.mock import Mock
+
+from core.tools.__base.tool import Tool
+from core.tools.entities.api_entities import ToolApiEntity
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolParameter
+from services.tools.tools_transform_service import ToolTransformService
+
+
+class TestToolTransformService:
+ """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
+
+ def test_convert_tool_with_parameter_override(self):
+ """Test that runtime parameters correctly override base parameters"""
+ # Create mock base parameters
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ base_param2 = Mock(spec=ToolParameter)
+ base_param2.name = "param2"
+ base_param2.form = ToolParameter.ToolParameterForm.FORM
+ base_param2.type = "string"
+ base_param2.label = "Base Param 2"
+
+ # Create mock runtime parameters that override base parameters
+ runtime_param1 = Mock(spec=ToolParameter)
+ runtime_param1.name = "param1"
+ runtime_param1.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param1.type = "string"
+ runtime_param1.label = "Runtime Param 1" # Different label to verify override
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1, base_param2]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param1]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.author == "test_author"
+ assert result.name == "test_tool"
+ assert result.parameters is not None
+ assert len(result.parameters) == 2
+
+ # Find the overridden parameter
+ overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
+ assert overridden_param is not None
+ assert overridden_param.label == "Runtime Param 1" # Should be runtime version
+
+ # Find the non-overridden parameter
+ original_param = next((p for p in result.parameters if p.name == "param2"), None)
+ assert original_param is not None
+ assert original_param.label == "Base Param 2" # Should be base version
+
+ def test_convert_tool_with_additional_runtime_parameters(self):
+ """Test that additional runtime parameters are added to the final list"""
+ # Create mock base parameters
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ # Create mock runtime parameters - one that overrides and one that's new
+ runtime_param1 = Mock(spec=ToolParameter)
+ runtime_param1.name = "param1"
+ runtime_param1.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param1.type = "string"
+ runtime_param1.label = "Runtime Param 1"
+
+ runtime_param2 = Mock(spec=ToolParameter)
+ runtime_param2.name = "runtime_only"
+ runtime_param2.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param2.type = "string"
+ runtime_param2.label = "Runtime Only Param"
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 2
+
+ # Check that both parameters are present
+ param_names = [p.name for p in result.parameters]
+ assert "param1" in param_names
+ assert "runtime_only" in param_names
+
+ # Verify the overridden parameter has runtime version
+ overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
+ assert overridden_param is not None
+ assert overridden_param.label == "Runtime Param 1"
+
+ # Verify the new runtime parameter is included
+ new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
+ assert new_param is not None
+ assert new_param.label == "Runtime Only Param"
+
+ def test_convert_tool_with_non_form_runtime_parameters(self):
+ """Test that non-FORM runtime parameters are not added as new parameters"""
+ # Create mock base parameters
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ # Create mock runtime parameters with different forms
+ runtime_param1 = Mock(spec=ToolParameter)
+ runtime_param1.name = "param1"
+ runtime_param1.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param1.type = "string"
+ runtime_param1.label = "Runtime Param 1"
+
+ runtime_param2 = Mock(spec=ToolParameter)
+ runtime_param2.name = "llm_param"
+ runtime_param2.form = ToolParameter.ToolParameterForm.LLM
+ runtime_param2.type = "string"
+ runtime_param2.label = "LLM Param"
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 1 # Only the FORM parameter should be present
+
+ # Check that only the FORM parameter is present
+ param_names = [p.name for p in result.parameters]
+ assert "param1" in param_names
+ assert "llm_param" not in param_names
+
+ def test_convert_tool_with_empty_parameters(self):
+ """Test conversion with empty base and runtime parameters"""
+ # Create mock tool with no parameters
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = []
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = []
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 0
+
+ def test_convert_tool_with_none_parameters(self):
+ """Test conversion when base parameters is None"""
+ # Create mock tool with None parameters
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = None
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = []
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 0
+
+ def test_convert_tool_parameter_order_preserved(self):
+ """Test that parameter order is preserved correctly"""
+ # Create mock base parameters in specific order
+ base_param1 = Mock(spec=ToolParameter)
+ base_param1.name = "param1"
+ base_param1.form = ToolParameter.ToolParameterForm.FORM
+ base_param1.type = "string"
+ base_param1.label = "Base Param 1"
+
+ base_param2 = Mock(spec=ToolParameter)
+ base_param2.name = "param2"
+ base_param2.form = ToolParameter.ToolParameterForm.FORM
+ base_param2.type = "string"
+ base_param2.label = "Base Param 2"
+
+ base_param3 = Mock(spec=ToolParameter)
+ base_param3.name = "param3"
+ base_param3.form = ToolParameter.ToolParameterForm.FORM
+ base_param3.type = "string"
+ base_param3.label = "Base Param 3"
+
+ # Create runtime parameter that overrides middle parameter
+ runtime_param2 = Mock(spec=ToolParameter)
+ runtime_param2.name = "param2"
+ runtime_param2.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param2.type = "string"
+ runtime_param2.label = "Runtime Param 2"
+
+ # Create new runtime parameter
+ runtime_param4 = Mock(spec=ToolParameter)
+ runtime_param4.name = "param4"
+ runtime_param4.form = ToolParameter.ToolParameterForm.FORM
+ runtime_param4.type = "string"
+ runtime_param4.label = "Runtime Param 4"
+
+ # Create mock tool
+ mock_tool = Mock(spec=Tool)
+ mock_tool.entity = Mock()
+ mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
+ mock_tool.entity.identity = Mock()
+ mock_tool.entity.identity.author = "test_author"
+ mock_tool.entity.identity.name = "test_tool"
+ mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
+ mock_tool.entity.description = Mock()
+ mock_tool.entity.description.human = I18nObject(en_US="Test description")
+ mock_tool.entity.output_schema = {}
+ mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
+
+ # Mock fork_tool_runtime to return the same tool
+ mock_tool.fork_tool_runtime.return_value = mock_tool
+
+ # Call the method
+ result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
+
+ # Verify the result
+ assert isinstance(result, ToolApiEntity)
+ assert result.parameters is not None
+ assert len(result.parameters) == 4
+
+ # Check that order is maintained: base parameters first, then new runtime parameters
+ param_names = [p.name for p in result.parameters]
+ assert param_names == ["param1", "param2", "param3", "param4"]
+
+ # Verify that param2 was overridden with runtime version
+ param2 = result.parameters[1]
+ assert param2.name == "param2"
+ assert param2.label == "Runtime Param 2"
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
index 223020c2c5..dfe325648d 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
@@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
@pytest.fixture
def workflow_setup():
- workflow_service = WorkflowService()
+ mock_session_maker = MagicMock()
+ workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session)
tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id"
@@ -42,7 +43,7 @@ def test_delete_workflow_success(workflow_setup):
# Setup mocks
# Mock the tool provider query to return None (not published as a tool)
- workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None
+ workflow_setup["session"].query.return_value.where.return_value.first.return_value = None
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
@@ -105,7 +106,7 @@ def test_delete_workflow_published_as_tool_error(workflow_setup):
# Mock the tool provider query
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
- workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider
+ workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
index c5c9cf1050..8b1348b75b 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py
@@ -1,14 +1,14 @@
import dataclasses
import secrets
-from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
import pytest
+from sqlalchemy import Engine
from sqlalchemy.orm import Session
from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
-from core.workflow.nodes import NodeType
+from core.workflow.nodes.enums import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import (
@@ -18,13 +18,25 @@ from services.workflow_draft_variable_service import (
)
+@pytest.fixture
+def mock_engine() -> Engine:
+ return Mock(spec=Engine)
+
+
+@pytest.fixture
+def mock_session(mock_engine) -> Session:
+ mock_session = Mock(spec=Session)
+ mock_session.get_bind.return_value = mock_engine
+ return mock_session
+
+
class TestDraftVariableSaver:
def _get_test_app_id(self):
suffix = secrets.token_hex(6)
return f"test_app_id_{suffix}"
def test__should_variable_be_visible(self):
- mock_session = mock.MagicMock(spec=Session)
+ mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -70,7 +82,7 @@ class TestDraftVariableSaver:
),
]
- mock_session = mock.MagicMock(spec=Session)
+ mock_session = MagicMock(spec=Session)
test_app_id = self._get_test_app_id()
saver = DraftVariableSaver(
session=mock_session,
@@ -105,9 +117,8 @@ class TestWorkflowDraftVariableService:
conversation_variables=[],
)
- def test_reset_conversation_variable(self):
+ def test_reset_conversation_variable(self, mock_session):
"""Test resetting a conversation variable"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -131,9 +142,8 @@ class TestWorkflowDraftVariableService:
mock_reset_conv.assert_called_once_with(workflow, variable)
assert result == expected_result
- def test_reset_node_variable_with_no_execution_id(self):
+ def test_reset_node_variable_with_no_execution_id(self, mock_session):
"""Test resetting a node variable with no execution ID - should delete variable"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -158,11 +168,26 @@ class TestWorkflowDraftVariableService:
mock_session.flush.assert_called_once()
assert result is None
- def test_reset_node_variable_with_missing_execution_record(self):
+ def test_reset_node_variable_with_missing_execution_record(
+ self,
+ mock_engine,
+ mock_session,
+ monkeypatch,
+ ):
"""Test resetting a node variable when execution record doesn't exist"""
- mock_session = Mock(spec=Session)
+ mock_repo_session = Mock(spec=Session)
+
+ mock_session_maker = MagicMock()
+ # Mock the context manager protocol for sessionmaker
+ mock_session_maker.return_value.__enter__.return_value = mock_repo_session
+ mock_session_maker.return_value.__exit__.return_value = None
+ monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
+ # Mock the repository to return None (no execution record found)
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = None
+
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -171,24 +196,41 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
-
- # Mock session.scalars to return None (no execution record found)
- mock_scalars = Mock()
- mock_scalars.first.return_value = None
- mock_session.scalars.return_value = mock_scalars
+ # Variable is editable by default from factory method
result = service._reset_node_var_or_sys_var(workflow, variable)
+ mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
# Should delete the variable and return None
mock_session.delete.assert_called_once_with(instance=variable)
mock_session.flush.assert_called_once()
assert result is None
- def test_reset_node_variable_with_valid_execution_record(self):
+ def test_reset_node_variable_with_valid_execution_record(
+ self,
+ mock_session,
+ monkeypatch,
+ ):
"""Test resetting a node variable with valid execution record - should restore from execution"""
- mock_session = Mock(spec=Session)
+ mock_repo_session = Mock(spec=Session)
+
+ mock_session_maker = MagicMock()
+ # Mock the context manager protocol for sessionmaker
+ mock_session_maker.return_value.__enter__.return_value = mock_repo_session
+ mock_session_maker.return_value.__exit__.return_value = None
+        monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
+ # Create mock execution record
+ mock_execution = Mock(spec=WorkflowNodeExecutionModel)
+ mock_execution.outputs_dict = {"test_var": "output_value"}
+
+ # Mock the repository to return the execution record
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
+
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@@ -197,16 +239,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
-
- # Create mock execution record
- mock_execution = Mock(spec=WorkflowNodeExecutionModel)
- mock_execution.process_data_dict = {"test_var": "process_value"}
- mock_execution.outputs_dict = {"test_var": "output_value"}
-
- # Mock session.scalars to return the execution record
- mock_scalars = Mock()
- mock_scalars.first.return_value = mock_execution
- mock_session.scalars.return_value = mock_scalars
+ # Variable is editable by default from factory method
# Mock workflow methods
mock_node_config = {"type": "test_node"}
@@ -224,9 +257,8 @@ class TestWorkflowDraftVariableService:
# Should return the updated variable
assert result == variable
- def test_reset_non_editable_system_variable_raises_error(self):
+ def test_reset_non_editable_system_variable_raises_error(self, mock_session):
"""Test that resetting a non-editable system variable raises an error"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -242,24 +274,13 @@ class TestWorkflowDraftVariableService:
editable=False, # Non-editable system variable
)
- # Mock the service to properly check system variable editability
- with patch.object(service, "reset_variable") as mock_reset:
+ with pytest.raises(VariableResetError) as exc_info:
+ service.reset_variable(workflow, variable)
+ assert "cannot reset system variable" in str(exc_info.value)
+ assert f"variable_id={variable.id}" in str(exc_info.value)
- def side_effect(wf, var):
- if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name):
- raise VariableResetError(f"cannot reset system variable, variable_id={var.id}")
- return var
-
- mock_reset.side_effect = side_effect
-
- with pytest.raises(VariableResetError) as exc_info:
- service.reset_variable(workflow, variable)
- assert "cannot reset system variable" in str(exc_info.value)
- assert f"variable_id={variable.id}" in str(exc_info.value)
-
- def test_reset_editable_system_variable_succeeds(self):
+ def test_reset_editable_system_variable_succeeds(self, mock_session):
"""Test that resetting an editable system variable succeeds"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -279,10 +300,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.files": "[]"}
- # Mock session.scalars to return the execution record
- mock_scalars = Mock()
- mock_scalars.first.return_value = mock_execution
- mock_session.scalars.return_value = mock_scalars
+ # Mock the repository to return the execution record
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
@@ -291,9 +311,8 @@ class TestWorkflowDraftVariableService:
assert variable.last_edited_at is None
mock_session.flush.assert_called()
- def test_reset_query_system_variable_succeeds(self):
+ def test_reset_query_system_variable_succeeds(self, mock_session):
"""Test that resetting query system variable (another editable one) succeeds"""
- mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
test_app_id = self._get_test_app_id()
@@ -313,10 +332,9 @@ class TestWorkflowDraftVariableService:
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.outputs_dict = {"sys.query": "reset query"}
- # Mock session.scalars to return the execution record
- mock_scalars = Mock()
- mock_scalars.first.return_value = mock_execution
- mock_session.scalars.return_value = mock_scalars
+ # Mock the repository to return the execution record
+ service._api_node_execution_repo = Mock()
+ service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
result = service._reset_node_var_or_sys_var(workflow, variable)
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py
new file mode 100644
index 0000000000..32d2f8b7e0
--- /dev/null
+++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py
@@ -0,0 +1,288 @@
+from datetime import datetime
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+from sqlalchemy.orm import Session
+
+from models.workflow import WorkflowNodeExecutionModel
+from repositories.sqlalchemy_api_workflow_node_execution_repository import (
+ DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
+)
+
+
+class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
+ @pytest.fixture
+ def repository(self):
+ mock_session_maker = MagicMock()
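+        # A bare MagicMock stands in for the sessionmaker; each test wires its context manager to a mock Session.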
+ return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
+
+ @pytest.fixture
+ def mock_execution(self):
+ execution = MagicMock(spec=WorkflowNodeExecutionModel)
+ execution.id = str(uuid4())
+ execution.tenant_id = "tenant-123"
+ execution.app_id = "app-456"
+ execution.workflow_id = "workflow-789"
+ execution.workflow_run_id = "run-101"
+ execution.node_id = "node-202"
+ execution.index = 1
+ execution.created_at = "2023-01-01T00:00:00Z"
+ return execution
+
+ def test_get_node_last_execution_found(self, repository, mock_execution):
+ """Test getting the last execution for a node when it exists."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
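+        # Assumes the repository opens sessions via "with self._session_maker() as session:"; entering the context yields mock_session.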
+ mock_session.scalar.return_value = mock_execution
+
+ # Act
+ result = repository.get_node_last_execution(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_id="workflow-789",
+ node_id="node-202",
+ )
+
+ # Assert
+ assert result == mock_execution
+ mock_session.scalar.assert_called_once()
+ # Verify the query was constructed correctly
+ call_args = mock_session.scalar.call_args[0][0]
+ assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
+
+ def test_get_node_last_execution_not_found(self, repository):
+ """Test getting the last execution for a node when it doesn't exist."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = None
+
+ # Act
+ result = repository.get_node_last_execution(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_id="workflow-789",
+ node_id="node-202",
+ )
+
+ # Assert
+ assert result is None
+ mock_session.scalar.assert_called_once()
+
+ def test_get_executions_by_workflow_run(self, repository, mock_execution):
+ """Test getting all executions for a workflow run."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ executions = [mock_execution]
+ mock_session.execute.return_value.scalars.return_value.all.return_value = executions
+
+ # Act
+ result = repository.get_executions_by_workflow_run(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_run_id="run-101",
+ )
+
+ # Assert
+ assert result == executions
+ mock_session.execute.assert_called_once()
+        # Verify a SQLAlchemy statement was handed to the session
+ call_args = mock_session.execute.call_args[0][0]
+ assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
+
+ def test_get_executions_by_workflow_run_empty(self, repository):
+ """Test getting executions for a workflow run when none exist."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.execute.return_value.scalars.return_value.all.return_value = []
+
+ # Act
+ result = repository.get_executions_by_workflow_run(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ workflow_run_id="run-101",
+ )
+
+ # Assert
+ assert result == []
+ mock_session.execute.assert_called_once()
+
+ def test_get_execution_by_id_found(self, repository, mock_execution):
+ """Test getting execution by ID when it exists."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = mock_execution
+
+ # Act
+ result = repository.get_execution_by_id(mock_execution.id)
+
+ # Assert
+ assert result == mock_execution
+ mock_session.scalar.assert_called_once()
+
+ def test_get_execution_by_id_not_found(self, repository):
+ """Test getting execution by ID when it doesn't exist."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+ mock_session.scalar.return_value = None
+
+ # Act
+ result = repository.get_execution_by_id("non-existent-id")
+
+ # Assert
+ assert result is None
+ mock_session.scalar.assert_called_once()
+
+ def test_repository_implements_protocol(self, repository):
+ """Test that the repository implements the required protocol methods."""
+ # Verify all protocol methods are implemented
+ assert hasattr(repository, "get_node_last_execution")
+ assert hasattr(repository, "get_executions_by_workflow_run")
+ assert hasattr(repository, "get_execution_by_id")
+
+ # Verify methods are callable
+ assert callable(repository.get_node_last_execution)
+ assert callable(repository.get_executions_by_workflow_run)
+ assert callable(repository.get_execution_by_id)
+ assert callable(repository.delete_expired_executions)
+ assert callable(repository.delete_executions_by_app)
+ assert callable(repository.get_expired_executions_batch)
+ assert callable(repository.delete_executions_by_ids)
+
+ def test_delete_expired_executions(self, repository):
+ """Test deleting expired executions."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+        # Return a single batch smaller than batch_size so the delete loop stops after one pass
+ execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
+
+ # Mock execute method to handle both select and delete statements
+ def mock_execute(stmt):
+ mock_result = MagicMock()
+ # For select statements, return execution IDs
+ if hasattr(stmt, "limit"): # This is our select statement
+ mock_result.scalars.return_value.all.return_value = execution_ids
+ else: # This is our delete statement
+ mock_result.rowcount = 2
+ return mock_result
+
+ mock_session.execute.side_effect = mock_execute
+
+ before_date = datetime(2023, 1, 1)
+
+ # Act
+ result = repository.delete_expired_executions(
+ tenant_id="tenant-123",
+ before_date=before_date,
+ batch_size=1000,
+ )
+
+ # Assert
+ assert result == 2
+ assert mock_session.execute.call_count == 2 # One select call, one delete call
+ mock_session.commit.assert_called_once()
+
+ def test_delete_executions_by_app(self, repository):
+ """Test deleting executions by app."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+        # Return a single batch smaller than batch_size so the delete loop stops after one pass
+        execution_ids = ["id1", "id2"]  # Fewer than batch_size, so the loop breaks
+
+ # Mock execute method to handle both select and delete statements
+ def mock_execute(stmt):
+ mock_result = MagicMock()
+ # For select statements, return execution IDs
+ if hasattr(stmt, "limit"): # This is our select statement
+ mock_result.scalars.return_value.all.return_value = execution_ids
+ else: # This is our delete statement
+ mock_result.rowcount = 2
+ return mock_result
+
+ mock_session.execute.side_effect = mock_execute
+
+ # Act
+ result = repository.delete_executions_by_app(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ batch_size=1000,
+ )
+
+ # Assert
+ assert result == 2
+ assert mock_session.execute.call_count == 2 # One select call, one delete call
+ mock_session.commit.assert_called_once()
+
+ def test_get_expired_executions_batch(self, repository):
+ """Test getting expired executions batch for backup."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Create mock execution objects
+ mock_execution1 = MagicMock()
+ mock_execution1.id = "exec-1"
+ mock_execution2 = MagicMock()
+ mock_execution2.id = "exec-2"
+
+ mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
+
+ before_date = datetime(2023, 1, 1)
+
+ # Act
+ result = repository.get_expired_executions_batch(
+ tenant_id="tenant-123",
+ before_date=before_date,
+ batch_size=1000,
+ )
+
+ # Assert
+ assert len(result) == 2
+ assert result[0].id == "exec-1"
+ assert result[1].id == "exec-2"
+ mock_session.execute.assert_called_once()
+
+ def test_delete_executions_by_ids(self, repository):
+ """Test deleting executions by IDs."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Mock the delete query result
+ mock_result = MagicMock()
+ mock_result.rowcount = 3
+ mock_session.execute.return_value = mock_result
+
+ execution_ids = ["id1", "id2", "id3"]
+
+ # Act
+ result = repository.delete_executions_by_ids(execution_ids)
+
+ # Assert
+ assert result == 3
+ mock_session.execute.assert_called_once()
+ mock_session.commit.assert_called_once()
+
+ def test_delete_executions_by_ids_empty_list(self, repository):
+ """Test deleting executions with empty ID list."""
+ # Arrange
+ mock_session = MagicMock(spec=Session)
+ repository._session_maker.return_value.__enter__.return_value = mock_session
+
+ # Act
+ result = repository.delete_executions_by_ids([])
+
+ # Assert
+ assert result == 0
+        mock_session.execute.assert_not_called()
+ mock_session.commit.assert_not_called()
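+
+
+# A minimal sketch (helper name and query details are illustrative, not the
+# actual repository internals) of the batched-delete pattern that
+# test_delete_expired_executions and test_delete_executions_by_app assume:
+# select at most batch_size matching IDs, delete them with a single
+# DELETE ... WHERE id IN (...), and stop once a batch comes back smaller than
+# batch_size. The select is the statement carrying .limit(), which is how
+# mock_execute above tells the two statement types apart.
+def _sketch_delete_in_batches(session_maker, tenant_id: str, before_date: datetime, batch_size: int = 1000) -> int:
+    from sqlalchemy import delete, select
+
+    total_deleted = 0
+    with session_maker() as session:
+        while True:
+            ids = (
+                session.execute(
+                    select(WorkflowNodeExecutionModel.id)
+                    .where(
+                        WorkflowNodeExecutionModel.tenant_id == tenant_id,
+                        WorkflowNodeExecutionModel.created_at < before_date,
+                    )
+                    .limit(batch_size)
+                )
+                .scalars()
+                .all()
+            )
+            if not ids:
+                break
+            result = session.execute(delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(ids)))
+            total_deleted += result.rowcount
+            if len(ids) < batch_size:  # a short batch means nothing is left
+                break
+        session.commit()
+    return total_deleted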
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py
index 13393668ea..9700cbaf0e 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_service.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py
@@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService
class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
- return WorkflowService()
+ mock_session_maker = MagicMock()
+ return WorkflowService(mock_session_maker)
@pytest.fixture
def mock_app(self):
diff --git a/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py
new file mode 100644
index 0000000000..30990f8d50
--- /dev/null
+++ b/api/tests/unit_tests/utils/oauth_encryption/test_system_oauth_encryption.py
@@ -0,0 +1,619 @@
+import base64
+import hashlib
+from unittest.mock import patch
+
+import pytest
+from Crypto.Cipher import AES
+from Crypto.Random import get_random_bytes
+from Crypto.Util.Padding import pad
+
+from core.tools.utils.system_oauth_encryption import (
+ OAuthEncryptionError,
+ SystemOAuthEncrypter,
+ create_system_oauth_encrypter,
+ decrypt_system_oauth_params,
+ encrypt_system_oauth_params,
+ get_system_oauth_encrypter,
+)
+
+
+class TestSystemOAuthEncrypter:
+ """Test cases for SystemOAuthEncrypter class"""
+
+ def test_init_with_secret_key(self):
+ """Test initialization with provided secret key"""
+ secret_key = "test_secret_key"
+ encrypter = SystemOAuthEncrypter(secret_key=secret_key)
+ expected_key = hashlib.sha256(secret_key.encode()).digest()
+ assert encrypter.key == expected_key
+
+ def test_init_with_none_secret_key(self):
+ """Test initialization with None secret key falls back to config"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "config_secret"
+ encrypter = SystemOAuthEncrypter(secret_key=None)
+ expected_key = hashlib.sha256(b"config_secret").digest()
+ assert encrypter.key == expected_key
+
+ def test_init_with_empty_secret_key(self):
+ """Test initialization with empty secret key"""
+ encrypter = SystemOAuthEncrypter(secret_key="")
+ expected_key = hashlib.sha256(b"").digest()
+ assert encrypter.key == expected_key
+
+ def test_init_without_secret_key_uses_config(self):
+ """Test initialization without secret key uses config"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "default_secret"
+ encrypter = SystemOAuthEncrypter()
+ expected_key = hashlib.sha256(b"default_secret").digest()
+ assert encrypter.key == expected_key
+
+ def test_encrypt_oauth_params_basic(self):
+ """Test basic OAuth parameters encryption"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+ # Should be valid base64
+ try:
+ base64.b64decode(encrypted)
+ except Exception:
+ pytest.fail("Encrypted result is not valid base64")
+
+ def test_encrypt_oauth_params_empty_dict(self):
+ """Test encryption with empty dictionary"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {}
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_complex_data(self):
+ """Test encryption with complex data structures"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {
+ "client_id": "test_id",
+ "client_secret": "test_secret",
+ "scopes": ["read", "write", "admin"],
+ "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
+ "numeric_value": 42,
+ "boolean_value": False,
+ "null_value": None,
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_unicode_data(self):
+ """Test encryption with unicode data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_large_data(self):
+ """Test encryption with large data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {
+ "client_id": "test_id",
+ "large_data": "x" * 10000, # 10KB of data
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_encrypt_oauth_params_invalid_input(self):
+ """Test encryption with invalid input types"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(Exception): # noqa: B017
+ encrypter.encrypt_oauth_params(None) # type: ignore
+
+ with pytest.raises(Exception): # noqa: B017
+ encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
+
+ def test_decrypt_oauth_params_basic(self):
+ """Test basic OAuth parameters decryption"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_empty_dict(self):
+ """Test decryption of empty dictionary"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {}
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_complex_data(self):
+ """Test decryption with complex data structures"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {
+ "client_id": "test_id",
+ "client_secret": "test_secret",
+ "scopes": ["read", "write", "admin"],
+ "metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
+ "numeric_value": 42,
+ "boolean_value": False,
+ "null_value": None,
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_unicode_data(self):
+ """Test decryption with unicode data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {
+ "client_id": "test_id",
+ "client_secret": "test_secret",
+ "description": "This is a test case 🚀",
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_large_data(self):
+ """Test decryption with large data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ original_params = {
+ "client_id": "test_id",
+ "large_data": "x" * 10000, # 10KB of data
+ }
+
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+
+ assert decrypted == original_params
+
+ def test_decrypt_oauth_params_invalid_base64(self):
+ """Test decryption with invalid base64 data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter.decrypt_oauth_params("invalid_base64!")
+
+ def test_decrypt_oauth_params_empty_string(self):
+ """Test decryption with empty string"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params("")
+
+ assert "encrypted_data cannot be empty" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_non_string_input(self):
+ """Test decryption with non-string input"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params(123) # type: ignore
+
+ assert "encrypted_data must be a string" in str(exc_info.value)
+
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params(None) # type: ignore
+
+ assert "encrypted_data must be a string" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_too_short_data(self):
+ """Test decryption with too short encrypted data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Create data that's too short (less than 32 bytes)
+ short_data = base64.b64encode(b"short").decode()
+
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.decrypt_oauth_params(short_data)
+
+ assert "Invalid encrypted data format" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_corrupted_data(self):
+ """Test decryption with corrupted data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Create corrupted data (valid base64 but invalid encrypted content)
+ corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter.decrypt_oauth_params(corrupted_data)
+
+ def test_decrypt_oauth_params_wrong_key(self):
+ """Test decryption with wrong key"""
+ encrypter1 = SystemOAuthEncrypter("secret1")
+ encrypter2 = SystemOAuthEncrypter("secret2")
+
+ original_params = {"client_id": "test_id", "client_secret": "test_secret"}
+ encrypted = encrypter1.encrypt_oauth_params(original_params)
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter2.decrypt_oauth_params(encrypted)
+
+ def test_encryption_decryption_consistency(self):
+ """Test that encryption and decryption are consistent"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ test_cases = [
+ {},
+ {"simple": "value"},
+ {"client_id": "id", "client_secret": "secret"},
+ {"complex": {"nested": {"deep": "value"}}},
+ {"unicode": "test 🚀"},
+ {"numbers": 42, "boolean": True, "null": None},
+ {"array": [1, 2, 3, "four", {"five": 5}]},
+ ]
+
+ for original_params in test_cases:
+ encrypted = encrypter.encrypt_oauth_params(original_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == original_params, f"Failed for case: {original_params}"
+
+ def test_encryption_randomness(self):
+ """Test that encryption produces different results for same input"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
+ encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
+
+ # Should be different due to random IV
+ assert encrypted1 != encrypted2
+
+ # But should decrypt to same result
+ decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
+ decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
+ assert decrypted1 == decrypted2 == oauth_params
+
+ def test_different_secret_keys_produce_different_results(self):
+ """Test that different secret keys produce different encrypted results"""
+ encrypter1 = SystemOAuthEncrypter("secret1")
+ encrypter2 = SystemOAuthEncrypter("secret2")
+
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
+ encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
+
+ # Should produce different encrypted results
+ assert encrypted1 != encrypted2
+
+ # But each should decrypt correctly with its own key
+ decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
+ decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
+ assert decrypted1 == decrypted2 == oauth_params
+
+ @patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
+ def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
+ """Test encryption when crypto operation fails"""
+ mock_get_random_bytes.side_effect = Exception("Crypto error")
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id"}
+
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.encrypt_oauth_params(oauth_params)
+
+ assert "Encryption failed" in str(exc_info.value)
+
+ @patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
+ def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
+ """Test encryption when JSON serialization fails"""
+ mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id"}
+
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.encrypt_oauth_params(oauth_params)
+
+ assert "Encryption failed" in str(exc_info.value)
+
+ def test_decrypt_oauth_params_invalid_json(self):
+ """Test decryption with invalid JSON data"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Create valid encrypted data but with invalid JSON content
+ iv = get_random_bytes(16)
+ cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
+ invalid_json = b"invalid json content"
+ padded_data = pad(invalid_json, AES.block_size)
+ encrypted_data = cipher.encrypt(padded_data)
+ combined = iv + encrypted_data
+ encoded = base64.b64encode(combined).decode()
+
+ with pytest.raises(OAuthEncryptionError):
+ encrypter.decrypt_oauth_params(encoded)
+
+ def test_key_derivation_consistency(self):
+ """Test that key derivation is consistent"""
+ secret_key = "test_secret"
+ encrypter1 = SystemOAuthEncrypter(secret_key)
+ encrypter2 = SystemOAuthEncrypter(secret_key)
+
+ assert encrypter1.key == encrypter2.key
+
+ # Keys should be 32 bytes (256 bits)
+ assert len(encrypter1.key) == 32
+
+
+class TestFactoryFunctions:
+ """Test cases for factory functions"""
+
+ def test_create_system_oauth_encrypter_with_secret(self):
+ """Test factory function with secret key"""
+ secret_key = "test_secret"
+ encrypter = create_system_oauth_encrypter(secret_key)
+
+ assert isinstance(encrypter, SystemOAuthEncrypter)
+ expected_key = hashlib.sha256(secret_key.encode()).digest()
+ assert encrypter.key == expected_key
+
+ def test_create_system_oauth_encrypter_without_secret(self):
+ """Test factory function without secret key"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "config_secret"
+ encrypter = create_system_oauth_encrypter()
+
+ assert isinstance(encrypter, SystemOAuthEncrypter)
+ expected_key = hashlib.sha256(b"config_secret").digest()
+ assert encrypter.key == expected_key
+
+ def test_create_system_oauth_encrypter_with_none_secret(self):
+ """Test factory function with None secret key"""
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "config_secret"
+ encrypter = create_system_oauth_encrypter(None)
+
+ assert isinstance(encrypter, SystemOAuthEncrypter)
+ expected_key = hashlib.sha256(b"config_secret").digest()
+ assert encrypter.key == expected_key
+
+
+class TestGlobalEncrypterInstance:
+ """Test cases for global encrypter instance"""
+
+ def test_get_system_oauth_encrypter_singleton(self):
+ """Test that get_system_oauth_encrypter returns singleton instance"""
+ # Clear the global instance first
+ import core.tools.utils.system_oauth_encryption
+
+ core.tools.utils.system_oauth_encryption._oauth_encrypter = None
+
+ encrypter1 = get_system_oauth_encrypter()
+ encrypter2 = get_system_oauth_encrypter()
+
+ assert encrypter1 is encrypter2
+ assert isinstance(encrypter1, SystemOAuthEncrypter)
+
+ def test_get_system_oauth_encrypter_uses_config(self):
+ """Test that global encrypter uses config"""
+ # Clear the global instance first
+ import core.tools.utils.system_oauth_encryption
+
+ core.tools.utils.system_oauth_encryption._oauth_encrypter = None
+
+ with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
+ mock_config.SECRET_KEY = "global_secret"
+ encrypter = get_system_oauth_encrypter()
+
+ expected_key = hashlib.sha256(b"global_secret").digest()
+ assert encrypter.key == expected_key
+
+
+class TestConvenienceFunctions:
+ """Test cases for convenience functions"""
+
+ def test_encrypt_system_oauth_params(self):
+ """Test encrypt_system_oauth_params convenience function"""
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypt_system_oauth_params(oauth_params)
+
+ assert isinstance(encrypted, str)
+ assert len(encrypted) > 0
+
+ def test_decrypt_system_oauth_params(self):
+ """Test decrypt_system_oauth_params convenience function"""
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ encrypted = encrypt_system_oauth_params(oauth_params)
+ decrypted = decrypt_system_oauth_params(encrypted)
+
+ assert decrypted == oauth_params
+
+ def test_convenience_functions_consistency(self):
+ """Test that convenience functions work consistently"""
+ test_cases = [
+ {},
+ {"simple": "value"},
+ {"client_id": "id", "client_secret": "secret"},
+ {"complex": {"nested": {"deep": "value"}}},
+ {"unicode": "test 🚀"},
+ {"numbers": 42, "boolean": True, "null": None},
+ ]
+
+ for original_params in test_cases:
+ encrypted = encrypt_system_oauth_params(original_params)
+ decrypted = decrypt_system_oauth_params(encrypted)
+ assert decrypted == original_params, f"Failed for case: {original_params}"
+
+ def test_convenience_functions_with_errors(self):
+ """Test convenience functions with error conditions"""
+ # Test encryption with invalid input
+ with pytest.raises(Exception): # noqa: B017
+ encrypt_system_oauth_params(None) # type: ignore
+
+ # Test decryption with invalid input
+ with pytest.raises(ValueError):
+ decrypt_system_oauth_params("")
+
+ with pytest.raises(ValueError):
+ decrypt_system_oauth_params(None) # type: ignore
+
+
+class TestErrorHandling:
+ """Test cases for error handling"""
+
+ def test_oauth_encryption_error_inheritance(self):
+ """Test that OAuthEncryptionError is a proper exception"""
+ error = OAuthEncryptionError("Test error")
+ assert isinstance(error, Exception)
+ assert str(error) == "Test error"
+
+ def test_oauth_encryption_error_with_cause(self):
+ """Test OAuthEncryptionError with cause"""
+ original_error = ValueError("Original error")
+ error = OAuthEncryptionError("Wrapper error")
+ error.__cause__ = original_error
+
+ assert isinstance(error, Exception)
+ assert str(error) == "Wrapper error"
+ assert error.__cause__ is original_error
+
+ def test_error_messages_are_informative(self):
+ """Test that error messages are informative"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+
+ # Test empty string error
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params("")
+ assert "encrypted_data cannot be empty" in str(exc_info.value)
+
+ # Test non-string error
+ with pytest.raises(ValueError) as exc_info:
+ encrypter.decrypt_oauth_params(123) # type: ignore
+ assert "encrypted_data must be a string" in str(exc_info.value)
+
+ # Test invalid format error
+ short_data = base64.b64encode(b"short").decode()
+ with pytest.raises(OAuthEncryptionError) as exc_info:
+ encrypter.decrypt_oauth_params(short_data)
+ assert "Invalid encrypted data format" in str(exc_info.value)
+
+
+class TestEdgeCases:
+ """Test cases for edge cases and boundary conditions"""
+
+ def test_very_long_secret_key(self):
+ """Test with very long secret key"""
+ long_secret = "x" * 10000
+ encrypter = SystemOAuthEncrypter(long_secret)
+
+ # Key should still be 32 bytes due to SHA-256
+ assert len(encrypter.key) == 32
+
+ # Should still work normally
+ oauth_params = {"client_id": "test_id"}
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_special_characters_in_secret_key(self):
+ """Test with special characters in secret key"""
+ special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
+ encrypter = SystemOAuthEncrypter(special_secret)
+
+ oauth_params = {"client_id": "test_id"}
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_empty_values_in_oauth_params(self):
+ """Test with empty values in oauth params"""
+ oauth_params = {
+ "client_id": "",
+ "client_secret": "",
+ "empty_dict": {},
+ "empty_list": [],
+ "empty_string": "",
+ "zero": 0,
+ "false": False,
+ "none": None,
+ }
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_deeply_nested_oauth_params(self):
+ """Test with deeply nested oauth params"""
+ oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_oauth_params_with_all_json_types(self):
+ """Test with all JSON-supported data types"""
+ oauth_params = {
+ "string": "test_string",
+ "integer": 42,
+ "float": 3.14159,
+ "boolean_true": True,
+ "boolean_false": False,
+ "null_value": None,
+ "empty_string": "",
+ "array": [1, "two", 3.0, True, False, None],
+ "object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
+ }
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+
+class TestPerformance:
+ """Test cases for performance considerations"""
+
+ def test_large_oauth_params(self):
+ """Test with large oauth params"""
+ large_value = "x" * 100000 # 100KB
+ oauth_params = {"client_id": "test_id", "large_data": large_value}
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_many_fields_oauth_params(self):
+ """Test with many fields in oauth params"""
+ oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
+
+ encrypter = SystemOAuthEncrypter("test_secret")
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
+
+ def test_repeated_encryption_decryption(self):
+ """Test repeated encryption and decryption operations"""
+ encrypter = SystemOAuthEncrypter("test_secret")
+ oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
+
+ # Test multiple rounds of encryption/decryption
+        for _ in range(100):
+ encrypted = encrypter.encrypt_oauth_params(oauth_params)
+ decrypted = encrypter.decrypt_oauth_params(encrypted)
+ assert decrypted == oauth_params
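+
+
+# A minimal sketch of the payload format these tests pin down (inferred from
+# test_decrypt_oauth_params_invalid_json and the key-derivation tests; the
+# real module may differ in details such as the JSON serialization step): the
+# key is SHA-256(secret_key), the ciphertext is AES-256-CBC over the padded
+# JSON bytes with a fresh 16-byte IV, and the result is base64(iv || ciphertext).
+# The random IV is why test_encryption_randomness expects two encryptions of
+# the same input to differ.
+def _sketch_encrypt(secret_key: str, oauth_params: dict) -> str:
+    import json
+
+    key = hashlib.sha256(secret_key.encode()).digest()
+    iv = get_random_bytes(16)
+    cipher = AES.new(key, AES.MODE_CBC, iv)
+    ciphertext = cipher.encrypt(pad(json.dumps(oauth_params).encode(), AES.block_size))
+    return base64.b64encode(iv + ciphertext).decode()
+
+
+def _sketch_decrypt(secret_key: str, encrypted_data: str) -> dict:
+    import json
+
+    from Crypto.Util.Padding import unpad
+
+    key = hashlib.sha256(secret_key.encode()).digest()
+    raw = base64.b64decode(encrypted_data)
+    iv, ciphertext = raw[:16], raw[16:]  # anything under 32 bytes is rejected as malformed upstream
+    cipher = AES.new(key, AES.MODE_CBC, iv)
+    return json.loads(unpad(cipher.decrypt(ciphertext), AES.block_size))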
diff --git a/api/tests/unit_tests/utils/position_helper/test_position_helper.py b/api/tests/unit_tests/utils/position_helper/test_position_helper.py
index 29558a93c2..dbd8f05098 100644
--- a/api/tests/unit_tests/utils/position_helper/test_position_helper.py
+++ b/api/tests/unit_tests/utils/position_helper/test_position_helper.py
@@ -95,7 +95,7 @@ def test_included_position_data(prepare_example_positions_yaml):
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
pin_list = ["forth", "first"]
include_set = {"forth", "first"}
- exclude_set = {}
+    exclude_set = set()  # a bare {} literal would create an empty dict, not an empty set
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)
diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
index 728c58fc5b..93284eed4b 100644
--- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
+++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py
@@ -27,11 +27,11 @@ def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LL
return LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=Decimal("0.001"),
- prompt_price_unit=Decimal("1"),
+ prompt_price_unit=Decimal(1),
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
completion_tokens=completion_tokens,
completion_unit_price=Decimal("0.002"),
- completion_price_unit=Decimal("1"),
+ completion_price_unit=Decimal(1),
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
total_tokens=prompt_tokens + completion_tokens,
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),
diff --git a/api/uv.lock b/api/uv.lock
index d379f28e52..623b125ab3 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -99,28 +99,29 @@ wheels = [
[[package]]
name = "aiosignal"
-version = "1.3.2"
+version = "1.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "frozenlist" },
+ { name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/ba/b5/6d55e80f6d8a08ce22b982eafa278d823b541c925f11ee774b0b9c43473d/aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54", size = 19424, upload-time = "2024-12-13T17:10:40.86Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597, upload-time = "2024-12-13T17:10:38.469Z" },
+ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
]
[[package]]
name = "alembic"
-version = "1.16.2"
+version = "1.16.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mako" },
{ name = "sqlalchemy" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/9c/35/116797ff14635e496bbda0c168987f5326a6555b09312e9b817e360d1f56/alembic-1.16.2.tar.gz", hash = "sha256:e53c38ff88dadb92eb22f8b150708367db731d58ad7e9d417c9168ab516cbed8", size = 1963563, upload-time = "2025-06-16T18:05:08.566Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/b9/40/28683414cc8711035a65256ca689e159471aa9ef08e8741ad1605bc01066/alembic-1.16.3.tar.gz", hash = "sha256:18ad13c1f40a5796deee4b2346d1a9c382f44b8af98053897484fa6cf88025e4", size = 1967462, upload-time = "2025-07-08T18:57:50.991Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/dd/e2/88e425adac5ad887a087c38d04fe2030010572a3e0e627f8a6e8c33eeda8/alembic-1.16.2-py3-none-any.whl", hash = "sha256:5f42e9bd0afdbd1d5e3ad856c01754530367debdebf21ed6894e34af52b3bb03", size = 242717, upload-time = "2025-06-16T18:05:10.27Z" },
+ { url = "https://files.pythonhosted.org/packages/e6/68/1dea77887af7304528ea944c355d769a7ccc4599d3a23bd39182486deb42/alembic-1.16.3-py3-none-any.whl", hash = "sha256:70a7c7829b792de52d08ca0e3aefaf060687cb8ed6bebfa557e597a1a5e5a481", size = 246933, upload-time = "2025-07-08T18:57:52.793Z" },
]
[[package]]
@@ -243,7 +244,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/22/8a/ef8ddf5ee0350984c
[[package]]
name = "alibabacloud-tea-openapi"
-version = "0.3.15"
+version = "0.3.16"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "alibabacloud-credentials" },
@@ -252,7 +253,7 @@ dependencies = [
{ name = "alibabacloud-tea-util" },
{ name = "alibabacloud-tea-xml" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/be/cb/f1b10b1da37e4c0de2aa9ca1e7153a6960a7f2dc496664e85fdc8b621f84/alibabacloud_tea_openapi-0.3.15.tar.gz", hash = "sha256:56a0aa6d51d8cf18c0cf3d219d861f4697f59d3e17fa6726b1101826d93988a2", size = 13021, upload-time = "2025-05-06T12:56:29.402Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/09/be/f594e79625e5ccfcfe7f12d7d70709a3c59e920878469c998886211c850d/alibabacloud_tea_openapi-0.3.16.tar.gz", hash = "sha256:6bffed8278597592e67860156f424bde4173a6599d7b6039fb640a3612bae292", size = 13087, upload-time = "2025-07-04T09:30:10.689Z" }
[[package]]
name = "alibabacloud-tea-util"
@@ -370,11 +371,11 @@ wheels = [
[[package]]
name = "asgiref"
-version = "3.8.1"
+version = "3.9.1"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/29/38/b3395cc9ad1b56d2ddac9970bc8f4141312dbaec28bc7c218b0dfafd0f42/asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590", size = 35186, upload-time = "2024-03-22T14:39:36.863Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/90/61/0aa957eec22ff70b830b22ff91f825e70e1ef732c06666a805730f28b36b/asgiref-3.9.1.tar.gz", hash = "sha256:a5ab6582236218e5ef1648f242fd9f10626cfd4de8dc377db215d5d5098e3142", size = 36870, upload-time = "2025-07-08T09:07:43.344Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/39/e3/893e8757be2612e6c266d9bb58ad2e3651524b5b40cf56761e985a28b13e/asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47", size = 23828, upload-time = "2024-03-22T14:39:34.521Z" },
+ { url = "https://files.pythonhosted.org/packages/7c/3c/0464dcada90d5da0e71018c04a140ad6349558afb30b3051b4264cc5b965/asgiref-3.9.1-py3-none-any.whl", hash = "sha256:f3bba7092a48005b5f5bacd747d36ee4a5a61f4a269a6df590b43144355ebd2c", size = 23790, upload-time = "2025-07-08T09:07:41.548Z" },
]
[[package]]
@@ -559,16 +560,16 @@ wheels = [
[[package]]
name = "boto3-stubs"
-version = "1.39.2"
+version = "1.39.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "botocore-stubs" },
{ name = "types-s3transfer" },
{ name = "typing-extensions", marker = "python_full_version < '3.12'" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/06/09/206a17938bfc7ec6e7c0b13ed58ad78146e46c29436d324ed55ceb5136ed/boto3_stubs-1.39.2.tar.gz", hash = "sha256:b1f1baef1658bd575a29ca85cc0877dbb3adeb376ffa8cbf242b876719ae0f95", size = 99939, upload-time = "2025-07-02T19:28:20.423Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/f0/ea/85b9940d6eedc04d0c6febf24d27311b6ee54f85ccc37192eb4db0dff5d6/boto3_stubs-1.39.3.tar.gz", hash = "sha256:9aad443b1d690951fd9ccb6fa20ad387bd0b1054c704566ff65dd0043a63fc26", size = 99947, upload-time = "2025-07-03T19:28:15.602Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/39/be/9c65f2bfc6df27ec5f16d28c454e2e3cb9a7af3ef8588440658334325a85/boto3_stubs-1.39.2-py3-none-any.whl", hash = "sha256:ce98d96fe1a7177b05067be3cd933277c88f745de836752f9ef8b4286dbfa53b", size = 69196, upload-time = "2025-07-02T19:28:07.025Z" },
+ { url = "https://files.pythonhosted.org/packages/be/b8/0c56297e5f290de17e838c7e4ff338f5b94351c6566aed70ee197a671dc5/boto3_stubs-1.39.3-py3-none-any.whl", hash = "sha256:4daddb19374efa6d1bef7aded9cede0075f380722a9e60ab129ebba14ae66b69", size = 69196, upload-time = "2025-07-03T19:28:09.4Z" },
]
[package.optional-dependencies]
@@ -1216,7 +1217,7 @@ wheels = [
[[package]]
name = "dify-api"
-version = "1.5.1"
+version = "1.7.0"
source = { virtual = "." }
dependencies = [
{ name = "arize-phoenix-otel" },
@@ -1245,6 +1246,7 @@ dependencies = [
{ name = "googleapis-common-protos" },
{ name = "gunicorn" },
{ name = "httpx", extra = ["socks"] },
+ { name = "httpx-sse" },
{ name = "jieba" },
{ name = "json-repair" },
{ name = "langfuse" },
@@ -1289,6 +1291,7 @@ dependencies = [
{ name = "sendgrid" },
{ name = "sentry-sdk", extra = ["flask"] },
{ name = "sqlalchemy" },
+ { name = "sseclient-py" },
{ name = "starlette" },
{ name = "tiktoken" },
{ name = "transformers" },
@@ -1425,6 +1428,7 @@ requires-dist = [
{ name = "googleapis-common-protos", specifier = "==1.63.0" },
{ name = "gunicorn", specifier = "~=23.0.0" },
{ name = "httpx", extras = ["socks"], specifier = "~=0.27.0" },
+ { name = "httpx-sse", specifier = ">=0.4.0" },
{ name = "jieba", specifier = "==0.42.1" },
{ name = "json-repair", specifier = ">=0.41.1" },
{ name = "langfuse", specifier = "~=2.51.3" },
@@ -1469,6 +1473,7 @@ requires-dist = [
{ name = "sendgrid", specifier = "~=6.12.3" },
{ name = "sentry-sdk", extras = ["flask"], specifier = "~=2.28.0" },
{ name = "sqlalchemy", specifier = "~=2.0.29" },
+ { name = "sseclient-py", specifier = ">=1.8.0" },
{ name = "starlette", specifier = "==0.41.0" },
{ name = "tiktoken", specifier = "~=0.9.0" },
{ name = "transformers", specifier = "~=4.51.0" },
@@ -1493,7 +1498,7 @@ dev = [
{ name = "pytest-cov", specifier = "~=4.1.0" },
{ name = "pytest-env", specifier = "~=1.1.3" },
{ name = "pytest-mock", specifier = "~=3.14.0" },
- { name = "ruff", specifier = "~=0.11.5" },
+ { name = "ruff", specifier = "~=0.12.3" },
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
{ name = "types-aiofiles", specifier = "~=24.1.0" },
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
@@ -1708,16 +1713,16 @@ wheels = [
[[package]]
name = "fastapi"
-version = "0.115.14"
+version = "0.116.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic" },
{ name = "starlette" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/ca/53/8c38a874844a8b0fa10dd8adf3836ac154082cf88d3f22b544e9ceea0a15/fastapi-0.115.14.tar.gz", hash = "sha256:b1de15cdc1c499a4da47914db35d0e4ef8f1ce62b624e94e0e5824421df99739", size = 296263, upload-time = "2025-06-26T15:29:08.21Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/20/38/e1da78736143fd885c36213a3ccc493c384ae8fea6a0f0bc272ef42ebea8/fastapi-0.116.0.tar.gz", hash = "sha256:80dc0794627af0390353a6d1171618276616310d37d24faba6648398e57d687a", size = 296518, upload-time = "2025-07-07T15:09:27.82Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/53/50/b1222562c6d270fea83e9c9075b8e8600b8479150a18e4516a6138b980d1/fastapi-0.115.14-py3-none-any.whl", hash = "sha256:6c0c8bf9420bd58f565e585036d971872472b4f7d3f6c73b698e10cffdefb3ca", size = 95514, upload-time = "2025-06-26T15:29:06.49Z" },
+ { url = "https://files.pythonhosted.org/packages/2f/68/d80347fe2360445b5f58cf290e588a4729746e7501080947e6cdae114b1f/fastapi-0.116.0-py3-none-any.whl", hash = "sha256:fdcc9ed272eaef038952923bef2b735c02372402d1203ee1210af4eea7a78d2b", size = 95625, upload-time = "2025-07-07T15:09:26.348Z" },
]
[[package]]
@@ -2532,6 +2537,15 @@ socks = [
{ name = "socksio" },
]
+[[package]]
+name = "httpx-sse"
+version = "0.4.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6e/fa/66bd985dd0b7c109a3bcb89272ee0bfb7e2b4d06309ad7b38ff866734b2a/httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e", size = 12998, upload-time = "2025-06-24T13:21:05.71Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/25/0a/6269e3473b09aed2dab8aa1a600c70f31f00ae1349bee30658f7e358a159/httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37", size = 8054, upload-time = "2025-06-24T13:21:04.772Z" },
+]
+
[[package]]
name = "huggingface-hub"
version = "0.33.2"
@@ -2574,15 +2588,15 @@ wheels = [
[[package]]
name = "hypothesis"
-version = "6.135.24"
+version = "6.135.26"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "sortedcontainers" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/cf/ae/f846b67ce9fc80cf51cece6b7adaa3fe2de4251242d142e241ce5d4aa26f/hypothesis-6.135.24.tar.gz", hash = "sha256:e301aeb2691ec0a1f62bfc405eaa966055d603e328cd854c1ed59e1728e35ab6", size = 454011, upload-time = "2025-07-03T02:46:51.776Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/da/83/15c4e30561a0d8c8d076c88cb159187823d877118f34c851ada3b9b02a7b/hypothesis-6.135.26.tar.gz", hash = "sha256:73af0e46cd5039c6806f514fed6a3c185d91ef88b5a1577477099ddbd1a2e300", size = 454523, upload-time = "2025-07-05T04:59:45.443Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ed/cb/c38acf27826a96712302229622f32dd356b9c4fbe52a3e9f615706027af8/hypothesis-6.135.24-py3-none-any.whl", hash = "sha256:88ed21fbfa481ca9851a9080841b3caca14cd4ed51a165dfae8006325775ee72", size = 520920, upload-time = "2025-07-03T02:46:48.286Z" },
+ { url = "https://files.pythonhosted.org/packages/3c/78/db4fdc464219455f8dde90074660c3faf8429101b2d1299cac7d219e3176/hypothesis-6.135.26-py3-none-any.whl", hash = "sha256:fa237cbe2ae2c31d65f7230dcb866139ace635dcfec6c30dddf25974dd8ff4b9", size = 521517, upload-time = "2025-07-05T04:59:42.061Z" },
]
[[package]]
@@ -2892,10 +2906,12 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9a/55/2cb24ea48aa30c99f805921c1c7860c1f45c0e811e44ee4e6a155668de06/lxml-6.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:219e0431ea8006e15005767f0351e3f7f9143e793e58519dc97fe9e07fae5563", size = 4952289, upload-time = "2025-06-28T18:47:25.602Z" },
{ url = "https://files.pythonhosted.org/packages/31/c0/b25d9528df296b9a3306ba21ff982fc5b698c45ab78b94d18c2d6ae71fd9/lxml-6.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bd5913b4972681ffc9718bc2d4c53cde39ef81415e1671ff93e9aa30b46595e7", size = 5111310, upload-time = "2025-06-28T18:47:28.136Z" },
{ url = "https://files.pythonhosted.org/packages/e9/af/681a8b3e4f668bea6e6514cbcb297beb6de2b641e70f09d3d78655f4f44c/lxml-6.0.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:390240baeb9f415a82eefc2e13285016f9c8b5ad71ec80574ae8fa9605093cd7", size = 5025457, upload-time = "2025-06-26T16:26:15.068Z" },
+ { url = "https://files.pythonhosted.org/packages/99/b6/3a7971aa05b7be7dfebc7ab57262ec527775c2c3c5b2f43675cac0458cad/lxml-6.0.0-cp312-cp312-manylinux_2_27_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d6e200909a119626744dd81bae409fc44134389e03fbf1d68ed2a55a2fb10991", size = 5657016, upload-time = "2025-07-03T19:19:06.008Z" },
{ url = "https://files.pythonhosted.org/packages/69/f8/693b1a10a891197143c0673fcce5b75fc69132afa81a36e4568c12c8faba/lxml-6.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ca50bd612438258a91b5b3788c6621c1f05c8c478e7951899f492be42defc0da", size = 5257565, upload-time = "2025-06-26T16:26:17.906Z" },
{ url = "https://files.pythonhosted.org/packages/a8/96/e08ff98f2c6426c98c8964513c5dab8d6eb81dadcd0af6f0c538ada78d33/lxml-6.0.0-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:c24b8efd9c0f62bad0439283c2c795ef916c5a6b75f03c17799775c7ae3c0c9e", size = 4713390, upload-time = "2025-06-26T16:26:20.292Z" },
{ url = "https://files.pythonhosted.org/packages/a8/83/6184aba6cc94d7413959f6f8f54807dc318fdcd4985c347fe3ea6937f772/lxml-6.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:afd27d8629ae94c5d863e32ab0e1d5590371d296b87dae0a751fb22bf3685741", size = 5066103, upload-time = "2025-06-26T16:26:22.765Z" },
{ url = "https://files.pythonhosted.org/packages/ee/01/8bf1f4035852d0ff2e36a4d9aacdbcc57e93a6cd35a54e05fa984cdf73ab/lxml-6.0.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:54c4855eabd9fc29707d30141be99e5cd1102e7d2258d2892314cf4c110726c3", size = 4791428, upload-time = "2025-06-26T16:26:26.461Z" },
+ { url = "https://files.pythonhosted.org/packages/29/31/c0267d03b16954a85ed6b065116b621d37f559553d9339c7dcc4943a76f1/lxml-6.0.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c907516d49f77f6cd8ead1322198bdfd902003c3c330c77a1c5f3cc32a0e4d16", size = 5678523, upload-time = "2025-07-03T19:19:09.837Z" },
{ url = "https://files.pythonhosted.org/packages/5c/f7/5495829a864bc5f8b0798d2b52a807c89966523140f3d6fa3a58ab6720ea/lxml-6.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36531f81c8214e293097cd2b7873f178997dae33d3667caaae8bdfb9666b76c0", size = 5281290, upload-time = "2025-06-26T16:26:29.406Z" },
{ url = "https://files.pythonhosted.org/packages/79/56/6b8edb79d9ed294ccc4e881f4db1023af56ba451909b9ce79f2a2cd7c532/lxml-6.0.0-cp312-cp312-win32.whl", hash = "sha256:690b20e3388a7ec98e899fd54c924e50ba6693874aa65ef9cb53de7f7de9d64a", size = 3613495, upload-time = "2025-06-26T16:26:31.588Z" },
{ url = "https://files.pythonhosted.org/packages/0b/1e/cc32034b40ad6af80b6fd9b66301fc0f180f300002e5c3eb5a6110a93317/lxml-6.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:310b719b695b3dd442cdfbbe64936b2f2e231bb91d998e99e6f0daf991a3eba3", size = 4014711, upload-time = "2025-06-26T16:26:33.723Z" },
@@ -3732,7 +3748,7 @@ wheels = [
[[package]]
name = "opik"
-version = "1.7.41"
+version = "1.7.43"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "boto3-stubs", extra = ["bedrock-runtime"] },
@@ -3751,9 +3767,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "uuid6" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/82/81/6cddb705b3f416cfe4f0507916f51d0886087695f9dab49cfc6b00eb0266/opik-1.7.41.tar.gz", hash = "sha256:6ce2f72c7d23a62e2c13d419ce50754f6e17234825dcf26506e7def34dd38e26", size = 323333, upload-time = "2025-07-02T12:35:31.76Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ba/52/cea0317bc3207bc967b48932781995d9cdb2c490e7e05caa00ff660f7205/opik-1.7.43.tar.gz", hash = "sha256:0b02522b0b74d0a67b141939deda01f8bb69690eda6b04a7cecb1c7f0649ccd0", size = 326886, upload-time = "2025-07-07T10:30:07.715Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/e9/46/ee27d06cc2049619806c992bdaa10e25b93d19ecedbc5c0fa772d8ac9a6d/opik-1.7.41-py3-none-any.whl", hash = "sha256:99df9c7b7b504777a51300b27a72bc646903201629611082b9b1f3c3adfbb3bf", size = 614890, upload-time = "2025-07-02T12:35:29.562Z" },
+ { url = "https://files.pythonhosted.org/packages/76/ae/f3566bdc3c49a1a8f795b1b6e726ef211c87e31f92d870ca6d63999c9bbf/opik-1.7.43-py3-none-any.whl", hash = "sha256:a66395c8b5ea7c24846f72dafc70c74d5b8f24ffbc4c8a1b3a7f9456e550568d", size = 625356, upload-time = "2025-07-07T10:30:06.389Z" },
]
[[package]]
@@ -3975,6 +3991,8 @@ sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/26/77f8ed17ca4ffd60e1dcd220a6ec6d71210ba398cfa33a13a1cd614c5613/pillow-11.3.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1cd110edf822773368b396281a2293aeb91c90a2db00d78ea43e7e861631b722", size = 5316531, upload-time = "2025-07-01T09:13:59.203Z" },
{ url = "https://files.pythonhosted.org/packages/cb/39/ee475903197ce709322a17a866892efb560f57900d9af2e55f86db51b0a5/pillow-11.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9c412fddd1b77a75aa904615ebaa6001f169b26fd467b4be93aded278266b288", size = 4686560, upload-time = "2025-07-01T09:14:01.101Z" },
+ { url = "https://files.pythonhosted.org/packages/d5/90/442068a160fd179938ba55ec8c97050a612426fae5ec0a764e345839f76d/pillow-11.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1aa4de119a0ecac0a34a9c8bde33f34022e2e8f99104e47a3ca392fd60e37d", size = 5870978, upload-time = "2025-07-03T13:09:55.638Z" },
+ { url = "https://files.pythonhosted.org/packages/13/92/dcdd147ab02daf405387f0218dcf792dc6dd5b14d2573d40b4caeef01059/pillow-11.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:91da1d88226663594e3f6b4b8c3c8d85bd504117d043740a8e0ec449087cc494", size = 7641168, upload-time = "2025-07-03T13:10:00.37Z" },
{ url = "https://files.pythonhosted.org/packages/6e/db/839d6ba7fd38b51af641aa904e2960e7a5644d60ec754c046b7d2aee00e5/pillow-11.3.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:643f189248837533073c405ec2f0bb250ba54598cf80e8c1e043381a60632f58", size = 5973053, upload-time = "2025-07-01T09:14:04.491Z" },
{ url = "https://files.pythonhosted.org/packages/f2/2f/d7675ecae6c43e9f12aa8d58b6012683b20b6edfbdac7abcb4e6af7a3784/pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:106064daa23a745510dabce1d84f29137a37224831d88eb4ce94bb187b1d7e5f", size = 6640273, upload-time = "2025-07-01T09:14:06.235Z" },
{ url = "https://files.pythonhosted.org/packages/45/ad/931694675ede172e15b2ff03c8144a0ddaea1d87adb72bb07655eaffb654/pillow-11.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cd8ff254faf15591e724dc7c4ddb6bf4793efcbe13802a4ae3e863cd300b493e", size = 6082043, upload-time = "2025-07-01T09:14:07.978Z" },
@@ -3984,6 +4002,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c6/df/90bd886fabd544c25addd63e5ca6932c86f2b701d5da6c7839387a076b4a/pillow-11.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:30807c931ff7c095620fe04448e2c2fc673fcbb1ffe2a7da3fb39613489b1ddd", size = 2423079, upload-time = "2025-07-01T09:14:15.268Z" },
{ url = "https://files.pythonhosted.org/packages/40/fe/1bc9b3ee13f68487a99ac9529968035cca2f0a51ec36892060edcc51d06a/pillow-11.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdae223722da47b024b867c1ea0be64e0df702c5e0a60e27daad39bf960dd1e4", size = 5278800, upload-time = "2025-07-01T09:14:17.648Z" },
{ url = "https://files.pythonhosted.org/packages/2c/32/7e2ac19b5713657384cec55f89065fb306b06af008cfd87e572035b27119/pillow-11.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:921bd305b10e82b4d1f5e802b6850677f965d8394203d182f078873851dada69", size = 4686296, upload-time = "2025-07-01T09:14:19.828Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/1e/b9e12bbe6e4c2220effebc09ea0923a07a6da1e1f1bfbc8d7d29a01ce32b/pillow-11.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:eb76541cba2f958032d79d143b98a3a6b3ea87f0959bbe256c0b5e416599fd5d", size = 5871726, upload-time = "2025-07-03T13:10:04.448Z" },
+ { url = "https://files.pythonhosted.org/packages/8d/33/e9200d2bd7ba00dc3ddb78df1198a6e80d7669cce6c2bdbeb2530a74ec58/pillow-11.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:67172f2944ebba3d4a7b54f2e95c786a3a50c21b88456329314caaa28cda70f6", size = 7644652, upload-time = "2025-07-03T13:10:10.391Z" },
{ url = "https://files.pythonhosted.org/packages/41/f1/6f2427a26fc683e00d985bc391bdd76d8dd4e92fac33d841127eb8fb2313/pillow-11.3.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97f07ed9f56a3b9b5f49d3661dc9607484e85c67e27f3e8be2c7d28ca032fec7", size = 5977787, upload-time = "2025-07-01T09:14:21.63Z" },
{ url = "https://files.pythonhosted.org/packages/e4/c9/06dd4a38974e24f932ff5f98ea3c546ce3f8c995d3f0985f8e5ba48bba19/pillow-11.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:676b2815362456b5b3216b4fd5bd89d362100dc6f4945154ff172e206a22c024", size = 6645236, upload-time = "2025-07-01T09:14:23.321Z" },
{ url = "https://files.pythonhosted.org/packages/40/e7/848f69fb79843b3d91241bad658e9c14f39a32f71a301bcd1d139416d1be/pillow-11.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3e184b2f26ff146363dd07bde8b711833d7b0202e27d13540bfe2e35a323a809", size = 6086950, upload-time = "2025-07-01T09:14:25.237Z" },
@@ -3993,6 +4013,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/16/8f/b13447d1bf0b1f7467ce7d86f6e6edf66c0ad7cf44cf5c87a37f9bed9936/pillow-11.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:2aceea54f957dd4448264f9bf40875da0415c83eb85f55069d89c0ed436e3542", size = 2423067, upload-time = "2025-07-01T09:14:33.709Z" },
{ url = "https://files.pythonhosted.org/packages/9e/e3/6fa84033758276fb31da12e5fb66ad747ae83b93c67af17f8c6ff4cc8f34/pillow-11.3.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7c8ec7a017ad1bd562f93dbd8505763e688d388cde6e4a010ae1486916e713e6", size = 5270566, upload-time = "2025-07-01T09:16:19.801Z" },
{ url = "https://files.pythonhosted.org/packages/5b/ee/e8d2e1ab4892970b561e1ba96cbd59c0d28cf66737fc44abb2aec3795a4e/pillow-11.3.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:9ab6ae226de48019caa8074894544af5b53a117ccb9d3b3dcb2871464c829438", size = 4654618, upload-time = "2025-07-01T09:16:21.818Z" },
+ { url = "https://files.pythonhosted.org/packages/f2/6d/17f80f4e1f0761f02160fc433abd4109fa1548dcfdca46cfdadaf9efa565/pillow-11.3.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fe27fb049cdcca11f11a7bfda64043c37b30e6b91f10cb5bab275806c32f6ab3", size = 4874248, upload-time = "2025-07-03T13:11:20.738Z" },
+ { url = "https://files.pythonhosted.org/packages/de/5f/c22340acd61cef960130585bbe2120e2fd8434c214802f07e8c03596b17e/pillow-11.3.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:465b9e8844e3c3519a983d58b80be3f668e2a7a5db97f2784e7079fbc9f9822c", size = 6583963, upload-time = "2025-07-03T13:11:26.283Z" },
{ url = "https://files.pythonhosted.org/packages/31/5e/03966aedfbfcbb4d5f8aa042452d3361f325b963ebbadddac05b122e47dd/pillow-11.3.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5418b53c0d59b3824d05e029669efa023bbef0f3e92e75ec8428f3799487f361", size = 4957170, upload-time = "2025-07-01T09:16:23.762Z" },
{ url = "https://files.pythonhosted.org/packages/cc/2d/e082982aacc927fc2cab48e1e731bdb1643a1406acace8bed0900a61464e/pillow-11.3.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:504b6f59505f08ae014f724b6207ff6222662aab5cc9542577fb084ed0676ac7", size = 5581505, upload-time = "2025-07-01T09:16:25.593Z" },
{ url = "https://files.pythonhosted.org/packages/34/e7/ae39f538fd6844e982063c3a5e4598b8ced43b9633baa3a85ef33af8c05c/pillow-11.3.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:c84d689db21a1c397d001aa08241044aa2069e7587b398c8cc63020390b1c1b8", size = 6984598, upload-time = "2025-07-01T09:16:27.732Z" },
@@ -4065,7 +4087,7 @@ wheels = [
[[package]]
name = "posthog"
-version = "6.0.2"
+version = "6.0.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backoff" },
@@ -4075,9 +4097,9 @@ dependencies = [
{ name = "six" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/d9/10/37ea988b3ae73cbfd1f2d5e523cca31cecfcc40cbd0de6511f40462fdb78/posthog-6.0.2.tar.gz", hash = "sha256:94a28e65d7a2d1b2952e53a1b97fa4d6504b8d7e4c197c57f653621e55b549eb", size = 88141, upload-time = "2025-07-02T19:21:50.306Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/39/a2/1b68562124b0d0e615fa8431cc88c84b3db6526275c2c19a419579a49277/posthog-6.0.3.tar.gz", hash = "sha256:9005abb341af8fedd9d82ca0359b3d35a9537555cdc9881bfb469f7c0b4b0ec5", size = 91861, upload-time = "2025-07-07T07:14:08.21Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/85/2c/0c5dbbf9bc30401ae2a1b6b52b8abc19e4060cf28c3288ae9d962e65e3ad/posthog-6.0.2-py3-none-any.whl", hash = "sha256:756cc9adad9e42961454f8ac391b92a2f70ebb6607d29b0c568de08e5d8f1b18", size = 104946, upload-time = "2025-07-02T19:21:48.77Z" },
+ { url = "https://files.pythonhosted.org/packages/ca/f1/a8d86245d41c8686f7d828a4959bdf483e8ac331b249b48b8c61fc884a1c/posthog-6.0.3-py3-none-any.whl", hash = "sha256:4b808c907f3623216a9362d91fdafce8e2f57a8387fb3020475c62ec809be56d", size = 108978, upload-time = "2025-07-07T07:14:06.451Z" },
]
[[package]]
@@ -4585,39 +4607,39 @@ wheels = [
[[package]]
name = "python-calamine"
-version = "0.3.2"
+version = "0.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "packaging" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/6b/21/387b92059909e741af7837194d84250335d2a057f614752b6364aaaa2f56/python_calamine-0.3.2.tar.gz", hash = "sha256:5cf12f2086373047cdea681711857b672cba77a34a66dd3755d60686fc974e06", size = 117336, upload-time = "2025-04-02T10:06:23.14Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/cc/03/269f96535705b2f18c8977fa58e76763b4e4727a9b3ae277a9468c8ffe05/python_calamine-0.4.0.tar.gz", hash = "sha256:94afcbae3fec36d2d7475095a59d4dc6fae45829968c743cb799ebae269d7bbf", size = 127737, upload-time = "2025-07-04T06:05:28.626Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ef/b7/d59863ebe319150739d0c352c6dea2710a2f90254ed32304d52e8349edce/python_calamine-0.3.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5251746816069c38eafdd1e4eb7b83870e1fe0ff6191ce9a809b187ffba8ce93", size = 830854, upload-time = "2025-04-02T10:04:14.673Z" },
- { url = "https://files.pythonhosted.org/packages/d3/01/b48c6f2c2e530a1a031199c5c5bf35f7c2cf7f16f3989263e616e3bc86ce/python_calamine-0.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9775dbc93bc635d48f45433f8869a546cca28c2a86512581a05333f97a18337b", size = 809411, upload-time = "2025-04-02T10:04:16.067Z" },
- { url = "https://files.pythonhosted.org/packages/fe/6d/69c53ffb11b3ee1bf5bd945cc2514848adea492c879a50f38e2ed4424727/python_calamine-0.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ff4318b72ba78e8a04fb4c45342cfa23eab6f81ecdb85548cdab9f2db8ac9c7", size = 872905, upload-time = "2025-04-02T10:04:17.487Z" },
- { url = "https://files.pythonhosted.org/packages/be/ec/b02c4bc04c426d153af1f5ff07e797dd81ada6f47c170e0207d07c90b53a/python_calamine-0.3.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0cd8eb1ef8644da71788a33d3de602d1c08ff1c4136942d87e25f09580b512ef", size = 876464, upload-time = "2025-04-02T10:04:19.53Z" },
- { url = "https://files.pythonhosted.org/packages/46/ef/8403ee595207de5bd277279b56384b31390987df8a61c280b4176802481a/python_calamine-0.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9dcfd560d8f88f39d23b829f666ebae4bd8daeec7ed57adfb9313543f3c5fa35", size = 942289, upload-time = "2025-04-02T10:04:20.902Z" },
- { url = "https://files.pythonhosted.org/packages/89/97/b4e5b77c70b36613c10f2dbeece75b5d43727335a33bf5176792ec83c3fc/python_calamine-0.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5e79b9eae4b30c82d045f9952314137c7089c88274e1802947f9e3adb778a59", size = 978699, upload-time = "2025-04-02T10:04:22.263Z" },
- { url = "https://files.pythonhosted.org/packages/5f/e9/03bbafd6b11cdf70c004f2e856978fc252ec5ea7e77529f14f969134c7a8/python_calamine-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce5e8cc518c8e3e5988c5c658f9dcd8229f5541ca63353175bb15b6ad8c456d0", size = 886008, upload-time = "2025-04-02T10:04:23.754Z" },
- { url = "https://files.pythonhosted.org/packages/7b/20/e18f534e49b403ba0b979a4dfead146001d867f5be846b91f81ed5377972/python_calamine-0.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a0e596b1346c28b2de15c9f86186cceefa4accb8882992aa0b7499c593446ed", size = 925104, upload-time = "2025-04-02T10:04:25.255Z" },
- { url = "https://files.pythonhosted.org/packages/54/4c/58933e69a0a7871487d10b958c1f83384bc430d53efbbfbf1dea141a0d85/python_calamine-0.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f521de16a9f3e951ec2e5e35d76752fe004088dbac4cdbf4dd62d0ad2bbf650f", size = 1050448, upload-time = "2025-04-02T10:04:26.649Z" },
- { url = "https://files.pythonhosted.org/packages/83/95/5c96d093eaaa2d15c63b43bcf8c87708eaab8428c72b6ebdcafc2604aa47/python_calamine-0.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:417d6825a36bba526ae17bed1b6ca576fbb54e23dc60c97eeb536c622e77c62f", size = 1056840, upload-time = "2025-04-02T10:04:28.18Z" },
- { url = "https://files.pythonhosted.org/packages/23/e0/b03cc3ad4f40fd3be0ebac0b71d273864ddf2bf0e611ec309328fdedded9/python_calamine-0.3.2-cp311-cp311-win32.whl", hash = "sha256:cd3ea1ca768139753633f9f0b16997648db5919894579f363d71f914f85f7ade", size = 663268, upload-time = "2025-04-02T10:04:29.659Z" },
- { url = "https://files.pythonhosted.org/packages/6b/bd/550da64770257fc70a185482f6353c0654a11f381227e146bb0170db040f/python_calamine-0.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:4560100412d8727c49048cca102eadeb004f91cfb9c99ae63cd7d4dc0a61333a", size = 692393, upload-time = "2025-04-02T10:04:31.534Z" },
- { url = "https://files.pythonhosted.org/packages/be/2e/0b4b7a146c3bb41116fe8e59a2f616340786db12aed51c7a9e75817cfa03/python_calamine-0.3.2-cp311-cp311-win_arm64.whl", hash = "sha256:a2526e6ba79087b1634f49064800339edb7316780dd7e1e86d10a0ca9de4e90f", size = 667312, upload-time = "2025-04-02T10:04:32.911Z" },
- { url = "https://files.pythonhosted.org/packages/f2/0f/c2e3e3bae774dae47cba6ffa640ff95525bd6a10a13d3cd998f33aeafc7f/python_calamine-0.3.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7c063b1f783352d6c6792305b2b0123784882e2436b638a9b9a1e97f6d74fa51", size = 825179, upload-time = "2025-04-02T10:04:34.377Z" },
- { url = "https://files.pythonhosted.org/packages/c7/81/a05285f06d71ea38ab99b09f3119f93f575487c9d24d7a1bab65657b258b/python_calamine-0.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:85016728937e8f5d1810ff3c9603ffd2458d66e34d495202d7759fa8219871cd", size = 804036, upload-time = "2025-04-02T10:04:35.938Z" },
- { url = "https://files.pythonhosted.org/packages/24/b5/320f366ffd91ee5d5f0f77817d4fb684f62a5a68e438dcdb90e4f5f35137/python_calamine-0.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81f243323bf712bb0b2baf0b938a2e6d6c9fa3b9902a44c0654474d04f999fac", size = 871527, upload-time = "2025-04-02T10:04:38.272Z" },
- { url = "https://files.pythonhosted.org/packages/13/19/063afced19620b829697b90329c62ad73274cc38faaa91d9ee41047f5f8c/python_calamine-0.3.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0b719dd2b10237b0cfb2062e3eaf199f220918a5623197e8449f37c8de845a7c", size = 875411, upload-time = "2025-04-02T10:04:39.647Z" },
- { url = "https://files.pythonhosted.org/packages/d7/6a/c93c52414ec62cc51c4820aff434f03c4a1c69ced15cec3e4b93885e4012/python_calamine-0.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d5158310b9140e8ee8665c9541a11030901e7275eb036988150c93f01c5133bf", size = 943525, upload-time = "2025-04-02T10:04:41.025Z" },
- { url = "https://files.pythonhosted.org/packages/0a/0a/5bdecee03d235e8d111b1e8ee3ea0c0ed4ae43a402f75cebbe719930cf04/python_calamine-0.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2c1b248e8bf10194c449cb57e6ccb3f2fe3dc86975a6d746908cf2d37b048cc", size = 976332, upload-time = "2025-04-02T10:04:42.454Z" },
- { url = "https://files.pythonhosted.org/packages/05/ad/43ff92366856ee34f958e9cf4f5b98e63b0dc219e06ccba4ad6f63463756/python_calamine-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a13ad8e5b6843a73933b8d1710bc4df39a9152cb57c11227ad51f47b5838a4", size = 885549, upload-time = "2025-04-02T10:04:43.869Z" },
- { url = "https://files.pythonhosted.org/packages/ff/b9/76afb867e2bb4bfc296446b741cee01ae4ce6a094b43f4ed4eaed5189de4/python_calamine-0.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fe950975a5758423c982ce1e2fdcb5c9c664d1a20b41ea21e619e5003bb4f96b", size = 926005, upload-time = "2025-04-02T10:04:45.884Z" },
- { url = "https://files.pythonhosted.org/packages/23/cf/5252b237b0e70c263f86741aea02e8e57aedb2bce9898468be1d9d55b9da/python_calamine-0.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:8707622ba816d6c26e36f1506ecda66a6a6cf43e55a43a8ef4c3bf8a805d3cfb", size = 1049380, upload-time = "2025-04-02T10:04:49.202Z" },
- { url = "https://files.pythonhosted.org/packages/1a/4d/f151e8923e53457ca49ceeaa3a34cb23afee7d7b46e6546ab2a29adc9125/python_calamine-0.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e6eac46475c26e162a037f6711b663767f61f8fca3daffeb35aa3fc7ee6267cc", size = 1056720, upload-time = "2025-04-02T10:04:51.002Z" },
- { url = "https://files.pythonhosted.org/packages/f5/cb/1b5db3e4a8bbaaaa7706b270570d4a65133618fa0ca7efafe5ce680f6cee/python_calamine-0.3.2-cp312-cp312-win32.whl", hash = "sha256:0dee82aedef3db27368a388d6741d69334c1d4d7a8087ddd33f1912166e17e37", size = 663502, upload-time = "2025-04-02T10:04:52.402Z" },
- { url = "https://files.pythonhosted.org/packages/5a/53/920fa8e7b570647c08da0f1158d781db2e318918b06cb28fe0363c3398ac/python_calamine-0.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:ae09b779718809d31ca5d722464be2776b7d79278b1da56e159bbbe11880eecf", size = 692660, upload-time = "2025-04-02T10:04:53.721Z" },
- { url = "https://files.pythonhosted.org/packages/a5/ea/5d0ecf5c345c4d78964a5f97e61848bc912965b276a54fb8ae698a9419a8/python_calamine-0.3.2-cp312-cp312-win_arm64.whl", hash = "sha256:435546e401a5821fa70048b6c03a70db3b27d00037e2c4999c2126d8c40b51df", size = 666205, upload-time = "2025-04-02T10:04:56.377Z" },
+ { url = "https://files.pythonhosted.org/packages/d4/a5/bcd82326d0ff1ab5889e7a5e13c868b483fc56398e143aae8e93149ba43b/python_calamine-0.4.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d1687f8c4d7852920c7b4e398072f183f88dd273baf5153391edc88b7454b8c0", size = 833019, upload-time = "2025-07-04T06:03:32.214Z" },
+ { url = "https://files.pythonhosted.org/packages/f6/1a/a681f1d2f28164552e91ef47bcde6708098aa64a5f5fe3952f22362d340a/python_calamine-0.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:258d04230bebbbafa370a15838049d912d6a0a2c4da128943d8160ca4b6db58e", size = 812268, upload-time = "2025-07-04T06:03:33.855Z" },
+ { url = "https://files.pythonhosted.org/packages/3d/92/2fc911431733739d4e7a633cefa903fa49a6b7a61e8765bad29a4a7c47b1/python_calamine-0.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c686e491634934f059553d55f77ac67ca4c235452d5b444f98fe79b3579f1ea5", size = 875733, upload-time = "2025-07-04T06:03:35.154Z" },
+ { url = "https://files.pythonhosted.org/packages/f4/f0/48bfae6802eb360028ca6c15e9edf42243aadd0006b6ac3e9edb41a57119/python_calamine-0.4.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4480af7babcc2f919c638a554b06b7b145d9ab3da47fd696d68c2fc6f67f9541", size = 878325, upload-time = "2025-07-04T06:03:36.638Z" },
+ { url = "https://files.pythonhosted.org/packages/a4/dc/f8c956e15bac9d5d1e05cd1b907ae780e40522d2fd103c8c6e2f21dff4ed/python_calamine-0.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e405b87a8cd1e90a994e570705898634f105442029f25bab7da658ee9cbaa771", size = 1015038, upload-time = "2025-07-04T06:03:37.971Z" },
+ { url = "https://files.pythonhosted.org/packages/54/3f/e69ab97c7734fb850fba2f506b775912fd59f04e17488582c8fbf52dbc72/python_calamine-0.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a831345ee42615f0dfcb0ed60a3b1601d2f946d4166edae64fd9a6f9bbd57fc1", size = 924969, upload-time = "2025-07-04T06:03:39.253Z" },
+ { url = "https://files.pythonhosted.org/packages/79/03/b4c056b468908d87a3de94389166e0f4dba725a70bc39e03bc039ba96f6b/python_calamine-0.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9951b8e4cafb3e1623bb5dfc31a18d38ef43589275f9657e99dfcbe4c8c4b33e", size = 888020, upload-time = "2025-07-04T06:03:41.099Z" },
+ { url = "https://files.pythonhosted.org/packages/86/4f/b9092f7c970894054083656953184e44cb2dadff8852425e950d4ca419af/python_calamine-0.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a6619fe3b5c9633ed8b178684605f8076c9d8d85b29ade15f7a7713fcfdee2d0", size = 930337, upload-time = "2025-07-04T06:03:42.89Z" },
+ { url = "https://files.pythonhosted.org/packages/64/da/137239027bf253aabe7063450950085ec9abd827d0cbc5170f585f38f464/python_calamine-0.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2cc45b8e76ee331f6ea88ca23677be0b7a05b502cd4423ba2c2bc8dad53af1be", size = 1054568, upload-time = "2025-07-04T06:03:44.153Z" },
+ { url = "https://files.pythonhosted.org/packages/80/96/74c38bcf6b6825d5180c0e147b85be8c52dbfba11848b1e98ba358e32a64/python_calamine-0.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1b2cfb7ced1a7c80befa0cfddfe4aae65663eb4d63c4ae484b9b7a80ebe1b528", size = 1058317, upload-time = "2025-07-04T06:03:45.873Z" },
+ { url = "https://files.pythonhosted.org/packages/33/95/9d7b8fe8b32d99a6c79534df3132cfe40e9df4a0f5204048bf5e66ddbd93/python_calamine-0.4.0-cp311-cp311-win32.whl", hash = "sha256:04f4e32ee16814fc1fafc49300be8eeb280d94878461634768b51497e1444bd6", size = 663934, upload-time = "2025-07-04T06:03:47.407Z" },
+ { url = "https://files.pythonhosted.org/packages/7c/e3/1c6cd9fd499083bea6ff1c30033ee8215b9f64e862babf5be170cacae190/python_calamine-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:a8543f69afac2213c0257bb56215b03dadd11763064a9d6b19786f27d1bef586", size = 692535, upload-time = "2025-07-04T06:03:48.699Z" },
+ { url = "https://files.pythonhosted.org/packages/94/1c/3105d19fbab6b66874ce8831652caedd73b23b72e88ce18addf8ceca8c12/python_calamine-0.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:54622e35ec7c3b6f07d119da49aa821731c185e951918f152c2dbf3bec1e15d6", size = 671751, upload-time = "2025-07-04T06:03:49.979Z" },
+ { url = "https://files.pythonhosted.org/packages/63/60/f951513aaaa470b3a38a87d65eca45e0a02bc329b47864f5a17db563f746/python_calamine-0.4.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:74bca5d44a73acf3dcfa5370820797fcfd225c8c71abcddea987c5b4f5077e98", size = 826603, upload-time = "2025-07-04T06:03:51.245Z" },
+ { url = "https://files.pythonhosted.org/packages/76/3f/789955bbc77831c639890758f945eb2b25d6358065edf00da6751226cf31/python_calamine-0.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cf80178f5d1b0ee2ccfffb8549c50855f6249e930664adc5807f4d0d6c2b269c", size = 805826, upload-time = "2025-07-04T06:03:52.482Z" },
+ { url = "https://files.pythonhosted.org/packages/00/4c/f87d17d996f647030a40bfd124fe45fe893c002bee35ae6aca9910a923ae/python_calamine-0.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65cfef345386ae86f7720f1be93495a40fd7e7feabb8caa1df5025d7fbc58a1f", size = 874989, upload-time = "2025-07-04T06:03:53.794Z" },
+ { url = "https://files.pythonhosted.org/packages/47/d2/3269367303f6c0488cf1bfebded3f9fe968d118a988222e04c9b2636bf2e/python_calamine-0.4.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f23e6214dbf9b29065a5dcfd6a6c674dd0e251407298c9138611c907d53423ff", size = 877504, upload-time = "2025-07-04T06:03:55.095Z" },
+ { url = "https://files.pythonhosted.org/packages/f9/6d/c7ac35f5c7125e8bd07eb36773f300fda20dd2da635eae78a8cebb0b6ab7/python_calamine-0.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d792d304ee232ab01598e1d3ab22e074a32c2511476b5fb4f16f4222d9c2a265", size = 1014171, upload-time = "2025-07-04T06:03:56.777Z" },
+ { url = "https://files.pythonhosted.org/packages/f0/81/5ea8792a2e9ab5e2a05872db3a4d3ed3538ad5af1861282c789e2f13a8cf/python_calamine-0.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf813425918fd68f3e991ef7c4b5015be0a1a95fc4a8ab7e73c016ef1b881bb4", size = 926737, upload-time = "2025-07-04T06:03:58.024Z" },
+ { url = "https://files.pythonhosted.org/packages/cc/6e/989e56e6f073fc0981a74ba7a393881eb351bb143e5486aa629b5e5d6a8b/python_calamine-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbe2a0ccb4d003635888eea83a995ff56b0748c8c76fc71923544f5a4a7d4cd7", size = 887032, upload-time = "2025-07-04T06:03:59.298Z" },
+ { url = "https://files.pythonhosted.org/packages/5d/92/2c9bd64277c6fe4be695d7d5a803b38d953ec8565037486be7506642c27c/python_calamine-0.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a7b3bb5f0d910b9b03c240987560f843256626fd443279759df4e91b717826d2", size = 929700, upload-time = "2025-07-04T06:04:01.388Z" },
+ { url = "https://files.pythonhosted.org/packages/64/fa/fc758ca37701d354a6bc7d63118699f1c73788a1f2e1b44d720824992764/python_calamine-0.4.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bd2c0fc2b5eabd08ceac8a2935bffa88dbc6116db971aa8c3f244bad3fd0f644", size = 1053971, upload-time = "2025-07-04T06:04:02.704Z" },
+ { url = "https://files.pythonhosted.org/packages/65/52/40d7e08ae0ddba331cdc9f7fb3e92972f8f38d7afbd00228158ff6d1fceb/python_calamine-0.4.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:85b547cb1c5b692a0c2406678d666dbc1cec65a714046104683fe4f504a1721d", size = 1057057, upload-time = "2025-07-04T06:04:04.014Z" },
+ { url = "https://files.pythonhosted.org/packages/16/de/e8a071c0adfda73285d891898a24f6e99338328c404f497ff5b0e6bc3d45/python_calamine-0.4.0-cp312-cp312-win32.whl", hash = "sha256:4c2a1e3a0db4d6de4587999a21cc35845648c84fba81c03dd6f3072c690888e4", size = 665540, upload-time = "2025-07-04T06:04:05.679Z" },
+ { url = "https://files.pythonhosted.org/packages/5e/f2/7fdfada13f80db12356853cf08697ff4e38800a1809c2bdd26ee60962e7a/python_calamine-0.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b193c89ffcc146019475cd121c552b23348411e19c04dedf5c766a20db64399a", size = 695366, upload-time = "2025-07-04T06:04:06.977Z" },
+ { url = "https://files.pythonhosted.org/packages/20/66/d37412ad854480ce32f50d9f74f2a2f88b1b8a6fbc32f70aabf3211ae89e/python_calamine-0.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:43a0f15e0b60c75a71b21a012b911d5d6f5fa052afad2a8edbc728af43af0fcf", size = 670740, upload-time = "2025-07-04T06:04:08.656Z" },
]
[[package]]
@@ -5066,27 +5088,27 @@ wheels = [
[[package]]
name = "ruff"
-version = "0.11.13"
+version = "0.12.3"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/ed/da/9c6f995903b4d9474b39da91d2d626659af3ff1eeb43e9ae7c119349dba6/ruff-0.11.13.tar.gz", hash = "sha256:26fa247dc68d1d4e72c179e08889a25ac0c7ba4d78aecfc835d49cbfd60bf514", size = 4282054, upload-time = "2025-06-05T21:00:15.721Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/7d/ce/a11d381192966e0b4290842cc8d4fac7dc9214ddf627c11c1afff87da29b/ruff-0.11.13-py3-none-linux_armv6l.whl", hash = "sha256:4bdfbf1240533f40042ec00c9e09a3aade6f8c10b6414cf11b519488d2635d46", size = 10292516, upload-time = "2025-06-05T20:59:32.944Z" },
- { url = "https://files.pythonhosted.org/packages/78/db/87c3b59b0d4e753e40b6a3b4a2642dfd1dcaefbff121ddc64d6c8b47ba00/ruff-0.11.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aef9c9ed1b5ca28bb15c7eac83b8670cf3b20b478195bd49c8d756ba0a36cf48", size = 11106083, upload-time = "2025-06-05T20:59:37.03Z" },
- { url = "https://files.pythonhosted.org/packages/77/79/d8cec175856ff810a19825d09ce700265f905c643c69f45d2b737e4a470a/ruff-0.11.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53b15a9dfdce029c842e9a5aebc3855e9ab7771395979ff85b7c1dedb53ddc2b", size = 10436024, upload-time = "2025-06-05T20:59:39.741Z" },
- { url = "https://files.pythonhosted.org/packages/8b/5b/f6d94f2980fa1ee854b41568368a2e1252681b9238ab2895e133d303538f/ruff-0.11.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab153241400789138d13f362c43f7edecc0edfffce2afa6a68434000ecd8f69a", size = 10646324, upload-time = "2025-06-05T20:59:42.185Z" },
- { url = "https://files.pythonhosted.org/packages/6c/9c/b4c2acf24ea4426016d511dfdc787f4ce1ceb835f3c5fbdbcb32b1c63bda/ruff-0.11.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c51f93029d54a910d3d24f7dd0bb909e31b6cd989a5e4ac513f4eb41629f0dc", size = 10174416, upload-time = "2025-06-05T20:59:44.319Z" },
- { url = "https://files.pythonhosted.org/packages/f3/10/e2e62f77c65ede8cd032c2ca39c41f48feabedb6e282bfd6073d81bb671d/ruff-0.11.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1808b3ed53e1a777c2ef733aca9051dc9bf7c99b26ece15cb59a0320fbdbd629", size = 11724197, upload-time = "2025-06-05T20:59:46.935Z" },
- { url = "https://files.pythonhosted.org/packages/bb/f0/466fe8469b85c561e081d798c45f8a1d21e0b4a5ef795a1d7f1a9a9ec182/ruff-0.11.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d28ce58b5ecf0f43c1b71edffabe6ed7f245d5336b17805803312ec9bc665933", size = 12511615, upload-time = "2025-06-05T20:59:49.534Z" },
- { url = "https://files.pythonhosted.org/packages/17/0e/cefe778b46dbd0cbcb03a839946c8f80a06f7968eb298aa4d1a4293f3448/ruff-0.11.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55e4bc3a77842da33c16d55b32c6cac1ec5fb0fbec9c8c513bdce76c4f922165", size = 12117080, upload-time = "2025-06-05T20:59:51.654Z" },
- { url = "https://files.pythonhosted.org/packages/5d/2c/caaeda564cbe103bed145ea557cb86795b18651b0f6b3ff6a10e84e5a33f/ruff-0.11.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:633bf2c6f35678c56ec73189ba6fa19ff1c5e4807a78bf60ef487b9dd272cc71", size = 11326315, upload-time = "2025-06-05T20:59:54.469Z" },
- { url = "https://files.pythonhosted.org/packages/75/f0/782e7d681d660eda8c536962920c41309e6dd4ebcea9a2714ed5127d44bd/ruff-0.11.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ffbc82d70424b275b089166310448051afdc6e914fdab90e08df66c43bb5ca9", size = 11555640, upload-time = "2025-06-05T20:59:56.986Z" },
- { url = "https://files.pythonhosted.org/packages/5d/d4/3d580c616316c7f07fb3c99dbecfe01fbaea7b6fd9a82b801e72e5de742a/ruff-0.11.13-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a9ddd3ec62a9a89578c85842b836e4ac832d4a2e0bfaad3b02243f930ceafcc", size = 10507364, upload-time = "2025-06-05T20:59:59.154Z" },
- { url = "https://files.pythonhosted.org/packages/5a/dc/195e6f17d7b3ea6b12dc4f3e9de575db7983db187c378d44606e5d503319/ruff-0.11.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d237a496e0778d719efb05058c64d28b757c77824e04ffe8796c7436e26712b7", size = 10141462, upload-time = "2025-06-05T21:00:01.481Z" },
- { url = "https://files.pythonhosted.org/packages/f4/8e/39a094af6967faa57ecdeacb91bedfb232474ff8c3d20f16a5514e6b3534/ruff-0.11.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:26816a218ca6ef02142343fd24c70f7cd8c5aa6c203bca284407adf675984432", size = 11121028, upload-time = "2025-06-05T21:00:04.06Z" },
- { url = "https://files.pythonhosted.org/packages/5a/c0/b0b508193b0e8a1654ec683ebab18d309861f8bd64e3a2f9648b80d392cb/ruff-0.11.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:51c3f95abd9331dc5b87c47ac7f376db5616041173826dfd556cfe3d4977f492", size = 11602992, upload-time = "2025-06-05T21:00:06.249Z" },
- { url = "https://files.pythonhosted.org/packages/7c/91/263e33ab93ab09ca06ce4f8f8547a858cc198072f873ebc9be7466790bae/ruff-0.11.13-py3-none-win32.whl", hash = "sha256:96c27935418e4e8e77a26bb05962817f28b8ef3843a6c6cc49d8783b5507f250", size = 10474944, upload-time = "2025-06-05T21:00:08.459Z" },
- { url = "https://files.pythonhosted.org/packages/46/f4/7c27734ac2073aae8efb0119cae6931b6fb48017adf048fdf85c19337afc/ruff-0.11.13-py3-none-win_amd64.whl", hash = "sha256:29c3189895a8a6a657b7af4e97d330c8a3afd2c9c8f46c81e2fc5a31866517e3", size = 11548669, upload-time = "2025-06-05T21:00:11.147Z" },
- { url = "https://files.pythonhosted.org/packages/ec/bf/b273dd11673fed8a6bd46032c0ea2a04b2ac9bfa9c628756a5856ba113b0/ruff-0.11.13-py3-none-win_arm64.whl", hash = "sha256:b4385285e9179d608ff1d2fb9922062663c658605819a6876d8beef0c30b7f3b", size = 10683928, upload-time = "2025-06-05T21:00:13.758Z" },
+ { url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" },
+ { url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" },
+ { url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" },
+ { url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" },
+ { url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" },
+ { url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" },
+ { url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" },
+ { url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" },
+ { url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" },
+ { url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" },
+ { url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" },
+ { url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" },
+ { url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" },
+ { url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" },
+ { url = "https://files.pythonhosted.org/packages/fe/84/7cc7bd73924ee6be4724be0db5414a4a2ed82d06b30827342315a1be9e9c/ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d", size = 10589263, upload-time = "2025-07-11T13:21:07.148Z" },
+ { url = "https://files.pythonhosted.org/packages/07/87/c070f5f027bd81f3efee7d14cb4d84067ecf67a3a8efb43aadfc72aa79a6/ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7", size = 11695072, upload-time = "2025-07-11T13:21:11.004Z" },
+ { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" },
]
[[package]]
@@ -5297,6 +5319,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1c/fc/9ba22f01b5cdacc8f5ed0d22304718d2c758fce3fd49a5372b886a86f37c/sqlalchemy-2.0.41-py3-none-any.whl", hash = "sha256:57df5dc6fdb5ed1a88a1ed2195fd31927e705cad62dedd86b46972752a80f576", size = 1911224, upload-time = "2025-05-14T17:39:42.154Z" },
]
+[[package]]
+name = "sseclient-py"
+version = "1.8.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e8/ed/3df5ab8bb0c12f86c28d0cadb11ed1de44a92ed35ce7ff4fd5518a809325/sseclient-py-1.8.0.tar.gz", hash = "sha256:c547c5c1a7633230a38dc599a21a2dc638f9b5c297286b48b46b935c71fac3e8", size = 7791, upload-time = "2023-09-01T19:39:20.45Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/49/58/97655efdfeb5b4eeab85b1fc5d3fa1023661246c2ab2a26ea8e47402d4f2/sseclient_py-1.8.0-py2.py3-none-any.whl", hash = "sha256:4ecca6dc0b9f963f8384e9d7fd529bf93dd7d708144c4fb5da0e0a1a926fee83", size = 8828, upload-time = "2023-09-01T19:39:17.627Z" },
+]
+
[[package]]
name = "starlette"
version = "0.41.0"
@@ -5599,11 +5630,11 @@ wheels = [
[[package]]
name = "types-aiofiles"
-version = "24.1.0.20250606"
+version = "24.1.0.20250708"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/64/6e/fac4ffc896cb3faf2ac5d23747b65dd8bae1d9ee23305d1a3b12111c3989/types_aiofiles-24.1.0.20250606.tar.gz", hash = "sha256:48f9e26d2738a21e0b0f19381f713dcdb852a36727da8414b1ada145d40a18fe", size = 14364, upload-time = "2025-06-06T03:09:26.515Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/4a/d6/5c44761bc11cb5c7505013a39f397a9016bfb3a5c932032b2db16c38b87b/types_aiofiles-24.1.0.20250708.tar.gz", hash = "sha256:c8207ed7385491ce5ba94da02658164ebd66b69a44e892288c9f20cbbf5284ff", size = 14322, upload-time = "2025-07-08T03:14:44.814Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/71/de/f2fa2ab8a5943898e93d8036941e05bfd1e1f377a675ee52c7c307dccb75/types_aiofiles-24.1.0.20250606-py3-none-any.whl", hash = "sha256:e568c53fb9017c80897a9aa15c74bf43b7ee90e412286ec1e0912b6e79301aee", size = 14276, upload-time = "2025-06-06T03:09:25.662Z" },
+ { url = "https://files.pythonhosted.org/packages/44/e9/4e0cc79c630040aae0634ac9393341dc2aff1a5be454be9741cc6cc8989f/types_aiofiles-24.1.0.20250708-py3-none-any.whl", hash = "sha256:07f8f06465fd415d9293467d1c66cd074b2c3b62b679e26e353e560a8cf63720", size = 14320, upload-time = "2025-07-08T03:14:44.009Z" },
]
[[package]]
@@ -5659,11 +5690,11 @@ wheels = [
[[package]]
name = "types-defusedxml"
-version = "0.7.0.20250516"
+version = "0.7.0.20250708"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/55/9d/3ba8b80536402f1a125bc5a44d82ab686aafa55a85f56160e076b2ac30de/types_defusedxml-0.7.0.20250516.tar.gz", hash = "sha256:164c2945077fa450f24ed09633f8b3a80694687fefbbc1cba5f24e4ba570666b", size = 10298, upload-time = "2025-05-16T03:08:18.951Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/b9/4b/79d046a7211e110afd885be04bb9423546df2a662ed28251512d60e51fb6/types_defusedxml-0.7.0.20250708.tar.gz", hash = "sha256:7b785780cc11c18a1af086308bf94bf53a0907943a1d145dbe00189bef323cb8", size = 10541, upload-time = "2025-07-08T03:14:33.325Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/2e/7b/567b0978150edccf7fa3aa8f2566ea9c3ffc9481ce7d64428166934d6d7f/types_defusedxml-0.7.0.20250516-py3-none-any.whl", hash = "sha256:00e793e5c385c3e142d7c2acc3b4ccea2fe0828cee11e35501f0ba40386630a0", size = 12576, upload-time = "2025-05-16T03:08:17.892Z" },
+ { url = "https://files.pythonhosted.org/packages/24/f8/870de7fbd5fee5643f05061db948df6bd574a05a42aee91e37ad47c999ef/types_defusedxml-0.7.0.20250708-py3-none-any.whl", hash = "sha256:cc426cbc31c61a0f1b1c2ad9b9ef9ef846645f28fd708cd7727a6353b5c52e54", size = 13478, upload-time = "2025-07-08T03:14:32.633Z" },
]
[[package]]
@@ -5677,11 +5708,11 @@ wheels = [
[[package]]
name = "types-docutils"
-version = "0.21.0.20250604"
+version = "0.21.0.20250708"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/ef/d0/d28035370d669f14d4e23bd63d093207331f361afa24d2686d2c3fe6be8d/types_docutils-0.21.0.20250604.tar.gz", hash = "sha256:5a9cc7f5a4c5ef694aa0abc61111e0b1376a53dee90d65757f77f31acfcca8f2", size = 40953, upload-time = "2025-06-04T03:10:27.439Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/39/86/24394a71a04f416ca03df51863a3d3e2cd0542fdc40989188dca30ffb5bf/types_docutils-0.21.0.20250708.tar.gz", hash = "sha256:5625a82a9a2f26d8384545607c157e023a48ed60d940dfc738db125282864172", size = 42011, upload-time = "2025-07-08T03:14:24.214Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/89/91/887e9591c1ee50dfbf7c2fa2f3f51bc6db683013b6d2b0cd3983adf3d502/types_docutils-0.21.0.20250604-py3-none-any.whl", hash = "sha256:bfa8628176c06a80cdd1d6f3fb32e972e042db53538596488dfe0e9c5962b222", size = 65915, upload-time = "2025-06-04T03:10:26.067Z" },
+ { url = "https://files.pythonhosted.org/packages/bd/17/8c1153fc1576a0dcffdd157c69a12863c3f9485054256f6791ea17d95aed/types_docutils-0.21.0.20250708-py3-none-any.whl", hash = "sha256:166630d1aec18b9ca02547873210e04bf7674ba8f8da9cd9e6a5e77dc99372c2", size = 67953, upload-time = "2025-07-08T03:14:23.057Z" },
]
[[package]]
@@ -5733,11 +5764,11 @@ wheels = [
[[package]]
name = "types-html5lib"
-version = "1.1.11.20250516"
+version = "1.1.11.20250708"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/d0/ed/9f092ff479e2b5598941855f314a22953bb04b5fb38bcba3f880feb833ba/types_html5lib-1.1.11.20250516.tar.gz", hash = "sha256:65043a6718c97f7d52567cc0cdf41efbfc33b1f92c6c0c5e19f60a7ec69ae720", size = 16136, upload-time = "2025-05-16T03:07:12.231Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/d4/3b/1f5ba4358cfc1421cced5cdb9d2b08b4b99e4f9a41da88ce079f6d1a7bf1/types_html5lib-1.1.11.20250708.tar.gz", hash = "sha256:24321720fdbac71cee50d5a4bec9b7448495b7217974cffe3fcf1ede4eef7afe", size = 16799, upload-time = "2025-07-08T03:13:53.14Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/cc/3b/cb5b23c7b51bf48b8c9f175abb9dce2f1ecd2d2c25f92ea9f4e3720e9398/types_html5lib-1.1.11.20250516-py3-none-any.whl", hash = "sha256:5e407b14b1bd2b9b1107cbd1e2e19d4a0c46d60febd231c7ab7313d7405663c1", size = 21770, upload-time = "2025-05-16T03:07:11.102Z" },
+ { url = "https://files.pythonhosted.org/packages/a8/50/5fc23cf647eee23acdd337c8150861d39980cf11f33dd87f78e87d2a4bad/types_html5lib-1.1.11.20250708-py3-none-any.whl", hash = "sha256:bb898066b155de7081cb182179e2ded31b9e0e234605e2cb46536894e68a6954", size = 22913, upload-time = "2025-07-08T03:13:52.098Z" },
]
[[package]]
@@ -5856,11 +5887,11 @@ wheels = [
[[package]]
name = "types-pymysql"
-version = "1.1.0.20250516"
+version = "1.1.0.20250708"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/db/11/cdaa90b82cb25c5e04e75f0b0616872aa5775b001096779375084f8dbbcf/types_pymysql-1.1.0.20250516.tar.gz", hash = "sha256:fea4a9776101cf893dfc868f42ce10d2e46dcc498c792cc7c9c0fe00cb744234", size = 19640, upload-time = "2025-05-16T03:06:54.568Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/65/a3/db349a06c64b8c041c165fc470b81d37404ec342014625c7a6b7f7a4f680/types_pymysql-1.1.0.20250708.tar.gz", hash = "sha256:2cbd7cfcf9313eda784910578c4f1d06f8cc03a15cd30ce588aa92dd6255011d", size = 21715, upload-time = "2025-07-08T03:13:56.463Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/ab/64/129656e04ddda35d69faae914ce67cf60d83407ddd7afdef1e7c50bbb74a/types_pymysql-1.1.0.20250516-py3-none-any.whl", hash = "sha256:41c87a832e3ff503d5120cc6cebd64f6dcb3c407d9580a98b2cb3e3bcd109aa6", size = 20328, upload-time = "2025-05-16T03:06:53.681Z" },
+ { url = "https://files.pythonhosted.org/packages/88/e5/7f72c520f527175b6455e955426fd4f971128b4fa2f8ab2f505f254a1ddc/types_pymysql-1.1.0.20250708-py3-none-any.whl", hash = "sha256:9252966d2795945b2a7a53d5cdc49fe8e4e2f3dde4c104ed7fc782a83114e365", size = 22860, upload-time = "2025-07-08T03:13:55.367Z" },
]
[[package]]
@@ -5878,20 +5909,20 @@ wheels = [
[[package]]
name = "types-python-dateutil"
-version = "2.9.0.20250516"
+version = "2.9.0.20250708"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/ef/88/d65ed807393285204ab6e2801e5d11fbbea811adcaa979a2ed3b67a5ef41/types_python_dateutil-2.9.0.20250516.tar.gz", hash = "sha256:13e80d6c9c47df23ad773d54b2826bd52dbbb41be87c3f339381c1700ad21ee5", size = 13943, upload-time = "2025-05-16T03:06:58.385Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/c9/95/6bdde7607da2e1e99ec1c1672a759d42f26644bbacf939916e086db34870/types_python_dateutil-2.9.0.20250708.tar.gz", hash = "sha256:ccdbd75dab2d6c9696c350579f34cffe2c281e4c5f27a585b2a2438dd1d5c8ab", size = 15834, upload-time = "2025-07-08T03:14:03.382Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/c5/3f/b0e8db149896005adc938a1e7f371d6d7e9eca4053a29b108978ed15e0c2/types_python_dateutil-2.9.0.20250516-py3-none-any.whl", hash = "sha256:2b2b3f57f9c6a61fba26a9c0ffb9ea5681c9b83e69cd897c6b5f668d9c0cab93", size = 14356, upload-time = "2025-05-16T03:06:57.249Z" },
+ { url = "https://files.pythonhosted.org/packages/72/52/43e70a8e57fefb172c22a21000b03ebcc15e47e97f5cb8495b9c2832efb4/types_python_dateutil-2.9.0.20250708-py3-none-any.whl", hash = "sha256:4d6d0cc1cc4d24a2dc3816024e502564094497b713f7befda4d5bc7a8e3fd21f", size = 17724, upload-time = "2025-07-08T03:14:02.593Z" },
]
[[package]]
name = "types-python-http-client"
-version = "3.3.7.20240910"
+version = "3.3.7.20250708"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/e1/d7/bb2754c2d1b20c1890593ec89799c99e8875b04f474197c41354f41e9d31/types-python-http-client-3.3.7.20240910.tar.gz", hash = "sha256:8a6ebd30ad4b90a329ace69c240291a6176388624693bc971a5ecaa7e9b05074", size = 2804, upload-time = "2024-09-10T02:38:31.608Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/55/a0/0ad93698a3ebc6846ca23aca20ff6f6f8ebe7b4f0c1de7f19e87c03dbe8f/types_python_http_client-3.3.7.20250708.tar.gz", hash = "sha256:5f85b32dc64671a4e5e016142169aa187c5abed0b196680944e4efd3d5ce3322", size = 7707, upload-time = "2025-07-08T03:14:36.197Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/64/95/8f492d37d99630e096acbb4071788483282a34a73ae89dd1a5727f4189cc/types_python_http_client-3.3.7.20240910-py3-none-any.whl", hash = "sha256:58941bd986fb8bb0f4f782ef376be145ece8023f391364fbcd22bd26b13a140e", size = 3917, upload-time = "2024-09-10T02:38:30.261Z" },
+ { url = "https://files.pythonhosted.org/packages/85/4f/b88274658cf489e35175be8571c970e9a1219713bafd8fc9e166d7351ecb/types_python_http_client-3.3.7.20250708-py3-none-any.whl", hash = "sha256:e2fc253859decab36713d82fc7f205868c3ddeaee79dbb55956ad9ca77abe12b", size = 8890, upload-time = "2025-07-08T03:14:35.506Z" },
]
[[package]]
@@ -6040,11 +6071,11 @@ wheels = [
[[package]]
name = "typing-extensions"
-version = "4.14.0"
+version = "4.14.1"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/d1/bc/51647cd02527e87d05cb083ccc402f93e441606ff1f01739a62c8ad09ba5/typing_extensions-4.14.0.tar.gz", hash = "sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4", size = 107423, upload-time = "2025-06-02T14:52:11.399Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/98/5a/da40306b885cc8c09109dc2e1abd358d5684b1425678151cdaed4731c822/typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36", size = 107673, upload-time = "2025-07-04T13:28:34.16Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/69/e0/552843e0d356fbb5256d21449fa957fa4eff3bbc135a74a691ee70c7c5da/typing_extensions-4.14.0-py3-none-any.whl", hash = "sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af", size = 43839, upload-time = "2025-06-02T14:52:10.026Z" },
+ { url = "https://files.pythonhosted.org/packages/b5/00/d631e67a838026495268c2f6884f3711a15a9a2a96cd244fdaea53b823fb/typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76", size = 43906, upload-time = "2025-07-04T13:28:32.743Z" },
]
[[package]]
@@ -6172,7 +6203,7 @@ pptx = [
[[package]]
name = "unstructured-client"
-version = "0.37.4"
+version = "0.38.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiofiles" },
@@ -6183,9 +6214,9 @@ dependencies = [
{ name = "pypdf" },
{ name = "requests-toolbelt" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/6c/6f/8dd20dab879f25074d6abfbb98f77bb8efeea0ae1bdf9a414b3e73c152b6/unstructured_client-0.37.4.tar.gz", hash = "sha256:5a4029563c2f79de098374fd8a99090719df325b4bdcfa3a87820908f2c83e6c", size = 90481, upload-time = "2025-07-01T16:40:09.877Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/85/60/412092671bfc4952640739f2c0c9b2f4c8af26a3c921738fd12621b4ddd8/unstructured_client-0.38.1.tar.gz", hash = "sha256:43ab0670dd8ff53d71e74f9b6dfe490a84a5303dab80a4873e118a840c6d46ca", size = 91781, upload-time = "2025-07-03T15:46:35.054Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/93/09/4399b0c32564b1a19fef943b5acea5a16fa0c6aa7a320065ce726b8245c1/unstructured_client-0.37.4-py3-none-any.whl", hash = "sha256:31975c0ea4408e369e6aad11c9e746d1f3f14013ac5c89f9f8dbada3a21dcec0", size = 211242, upload-time = "2025-07-01T16:40:08.642Z" },
+ { url = "https://files.pythonhosted.org/packages/26/e0/8c249f00ba85fb4aba5c541463312befbfbf491105ff5c06e508089467be/unstructured_client-0.38.1-py3-none-any.whl", hash = "sha256:71e5467870d0a0119c788c29ec8baf5c0f7123f424affc9d6682eeeb7b8d45fa", size = 212626, upload-time = "2025-07-03T15:46:33.929Z" },
]
[[package]]
@@ -6220,11 +6251,11 @@ wheels = [
[[package]]
name = "uuid6"
-version = "2025.0.0"
+version = "2025.0.1"
source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/3f/49/06a089c184580f510e20226d9a081e4323d13db2fbc92d566697b5395c1e/uuid6-2025.0.0.tar.gz", hash = "sha256:bb78aa300e29db89b00410371d0c1f1824e59e29995a9daa3dedc8033d1d84ec", size = 13941, upload-time = "2025-06-11T20:02:05.324Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ca/b7/4c0f736ca824b3a25b15e8213d1bcfc15f8ac2ae48d1b445b310892dc4da/uuid6-2025.0.1.tar.gz", hash = "sha256:cd0af94fa428675a44e32c5319ec5a3485225ba2179eefcf4c3f205ae30a81bd", size = 13932, upload-time = "2025-07-04T18:30:35.186Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/0a/50/4da47101af45b6cfa291559577993b52ee4399b3cd54ba307574a11e4f3a/uuid6-2025.0.0-py3-none-any.whl", hash = "sha256:2c73405ff5333c7181443958c6865e0d1b9b816bb160549e8d80ba186263cb3a", size = 7001, upload-time = "2025-06-11T20:02:04.521Z" },
+ { url = "https://files.pythonhosted.org/packages/3d/b2/93faaab7962e2aa8d6e174afb6f76be2ca0ce89fde14d3af835acebcaa59/uuid6-2025.0.1-py3-none-any.whl", hash = "sha256:80530ce4d02a93cdf82e7122ca0da3ebbbc269790ec1cb902481fa3e9cc9ff99", size = 6979, upload-time = "2025-07-04T18:30:34.001Z" },
]
[[package]]
diff --git a/docker/.env.example b/docker/.env.example
index a024566c8f..88cc544730 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -47,6 +47,11 @@ APP_WEB_URL=
# ensuring port 5001 is externally accessible (see docker-compose.yaml).
FILES_URL=
+# INTERNAL_FILES_URL is used for plugin daemon communication within the Docker network.
+# Set this to the internal Docker service URL for proper plugin file access.
+# Example: INTERNAL_FILES_URL=http://api:5001
+INTERNAL_FILES_URL=
+
# ------------------------------
# Server Configuration
# ------------------------------
@@ -209,6 +214,10 @@ SQLALCHEMY_POOL_SIZE=30
SQLALCHEMY_POOL_RECYCLE=3600
# Whether to print SQL, default is false.
SQLALCHEMY_ECHO=false
+# If true, test connections for liveness upon each checkout from the pool.
+SQLALCHEMY_POOL_PRE_PING=false
+# If true, retrieve pooled connections in last-in-first-out (LIFO) order; if false, use the default FIFO queue.
+SQLALCHEMY_POOL_USE_LIFO=false
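+# Illustrative sketch (an assumption about the wiring, not a confirmed mapping):
+# these two flags correspond to SQLAlchemy's engine options, roughly
+#   engine = create_engine(url, pool_pre_ping=True, pool_use_lifo=True)
+# pre_ping trades a tiny per-checkout round trip for resilience against stale
+# connections; LIFO keeps handing out the most recently returned connection so
+# idle ones can time out and the pool can shrink under low load.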
# Maximum number of connections to the database
# Default is 100
@@ -274,12 +283,14 @@ REDIS_CLUSTERS_PASSWORD=
# Celery Configuration
# ------------------------------
-# Use redis as the broker, and redis db 1 for celery broker.
-# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`
+# Use standalone redis as the broker, and redis db 1 for celery broker. (redis_username is usually empty by default)
+# Format as follows: `redis://<redis_username>:<redis_password>@<redis_host>:<redis_port>/<redis_database>`.
# Example: redis://:difyai123456@redis:6379/1
-# If using Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
-# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
+# If using Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
+# For high availability, you can configure multiple Sentinel nodes separated by semicolons, as in the example below:
+# Example: sentinel://:difyai123456@localhost:26379/1;sentinel://:difyai123456@localhost:26380/1;sentinel://:difyai123456@localhost:26381/1
CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
+CELERY_BACKEND=redis
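+# Assumption: CELERY_BACKEND selects the Celery result backend; `redis` keeps
+# task results in Redis, while a value such as `database` would persist them in
+# the primary database instead.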
BROKER_USE_SSL=false
# If you are using Redis Sentinel for high availability, configure the following settings.
@@ -402,6 +413,8 @@ SUPABASE_URL=your-server-url
# The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`.
VECTOR_STORE=weaviate
+# Prefix used to create collection names in the vector database
+VECTOR_INDEX_NAME_PREFIX=Vector_index
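+# Resulting collection names are derived from this prefix plus the dataset id,
+# e.g. something like `Vector_index_<dataset_id>_Node` (the exact pattern is an
+# assumption and may vary by vector store).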
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT=http://weaviate:8080
@@ -763,6 +776,8 @@ INVITE_EXPIRY_HOURS=72
# Reset password token valid time (minutes)
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
+OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
# The sandbox service endpoint.
CODE_EXECUTION_ENDPOINT=http://sandbox:8194
@@ -794,6 +809,19 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
+# Repository configuration
+# Core workflow execution repository implementation
+CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
+
+# Core workflow node execution repository implementation
+CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow node execution repository implementation
+API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow run repository implementation
+API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
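+# Each repository value above is a dotted Python import path
+# (`package.module.ClassName`) resolved at startup, so a custom implementation
+# can be swapped in by pointing at your own class, e.g. (hypothetical name):
+# CORE_WORKFLOW_EXECUTION_REPOSITORY=my_pkg.repositories.MyWorkflowExecutionRepository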
+
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
@@ -826,6 +854,9 @@ MAX_ITERATIONS_NUM=99
# The timeout for the text generation in millisecond
TEXT_GENERATION_TIMEOUT_MS=60000
+# Allow rendering unsafe URLs that use the "data:" scheme.
+ALLOW_UNSAFE_DATA_SCHEME=false
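+# Kept disabled by default because data: URLs can inline arbitrary content,
+# e.g. `data:text/html;base64,...`, a common vector for XSS in rendered output.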
+
# ------------------------------
# Environment Variables for db Service
# ------------------------------
@@ -958,7 +989,7 @@ NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3
# Nginx performance tuning
NGINX_WORKER_PROCESSES=auto
-NGINX_CLIENT_MAX_BODY_SIZE=15M
+NGINX_CLIENT_MAX_BODY_SIZE=100M
NGINX_KEEPALIVE_TIMEOUT=65
# Proxy settings
@@ -1114,6 +1145,8 @@ PLUGIN_VOLCENGINE_TOS_REGION=
# OTLP Collector Configuration
# ------------------------------
ENABLE_OTEL=false
+OTLP_TRACE_ENDPOINT=
+OTLP_METRIC_ENDPOINT=
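+# Assumption: when set, these per-signal endpoints take precedence over
+# OTLP_BASE_ENDPOINT for traces and metrics respectively, mirroring the
+# OTEL_EXPORTER_OTLP_TRACES/METRICS_ENDPOINT convention.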
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_OTLP_PROTOCOL=
@@ -1135,3 +1168,13 @@ QUEUE_MONITOR_THRESHOLD=200
QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30
+
+# Celery schedule tasks configuration
+ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
+ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
+ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
+ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
+ENABLE_CLEAN_MESSAGES=false
+ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
+ENABLE_DATASETS_QUEUE_MONITOR=false
+ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
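+# These flags gate the periodic jobs registered with Celery beat; they only
+# take effect when a beat scheduler (the worker_beat service added below) is running.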
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml
index d45f8f8bfa..394a068200 100644
--- a/docker/docker-compose-template.yaml
+++ b/docker/docker-compose-template.yaml
@@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services:
# API service
api:
- image: langgenius/dify-api:1.5.1
+ image: langgenius/dify-api:1.7.0
restart: always
environment:
# Use the shared environment variables.
@@ -31,7 +31,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
- image: langgenius/dify-api:1.5.1
+ image: langgenius/dify-api:1.7.0
restart: always
environment:
# Use the shared environment variables.
@@ -55,9 +55,28 @@ services:
- ssrf_proxy_network
- default
+ # worker_beat service
+ # Celery beat for scheduling periodic tasks.
+ worker_beat:
+ image: langgenius/dify-api:1.7.0
+ restart: always
+ environment:
+ # Use the shared environment variables.
+ <<: *shared-api-worker-env
+ # Startup mode, 'beat' starts the Celery beat for scheduling periodic tasks.
+ MODE: beat
+ depends_on:
+ db:
+ condition: service_healthy
+ redis:
+ condition: service_started
+ networks:
+ - ssrf_proxy_network
+ - default
+
# Frontend web application.
web:
- image: langgenius/dify-web:1.5.1
+ image: langgenius/dify-web:1.7.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -67,6 +86,7 @@ services:
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
CSP_WHITELIST: ${CSP_WHITELIST:-}
ALLOW_EMBED: ${ALLOW_EMBED:-false}
+ ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
@@ -142,7 +162,7 @@ services:
# plugin daemon
plugin_daemon:
- image: langgenius/dify-plugin-daemon:0.1.3-local
+ image: langgenius/dify-plugin-daemon:0.2.0-local
restart: always
environment:
# Use the shared environment variables.
@@ -265,7 +285,7 @@ services:
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
- NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M}
+ NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index 0b1885755b..3408fef0c2 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -71,7 +71,7 @@ services:
# plugin daemon
plugin_daemon:
- image: langgenius/dify-plugin-daemon:0.1.3-local
+ image: langgenius/dify-plugin-daemon:0.2.0-local
restart: always
env_file:
- ./middleware.env
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 7f91fd8796..c2ef2ff723 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -11,6 +11,7 @@ x-shared-env: &shared-api-worker-env
APP_API_URL: ${APP_API_URL:-}
APP_WEB_URL: ${APP_WEB_URL:-}
FILES_URL: ${FILES_URL:-}
+ INTERNAL_FILES_URL: ${INTERNAL_FILES_URL:-}
LOG_LEVEL: ${LOG_LEVEL:-INFO}
LOG_FILE: ${LOG_FILE:-/app/logs/server.log}
LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20}
@@ -55,6 +56,8 @@ x-shared-env: &shared-api-worker-env
SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30}
SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600}
SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false}
+ SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false}
+ SQLALCHEMY_POOL_USE_LIFO: ${SQLALCHEMY_POOL_USE_LIFO:-false}
POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100}
POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB}
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
@@ -76,6 +79,7 @@ x-shared-env: &shared-api-worker-env
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
+ CELERY_BACKEND: ${CELERY_BACKEND:-redis}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false}
CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-}
@@ -132,6 +136,7 @@ x-shared-env: &shared-api-worker-env
SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key}
SUPABASE_URL: ${SUPABASE_URL:-your-server-url}
VECTOR_STORE: ${VECTOR_STORE:-weaviate}
+ VECTOR_INDEX_NAME_PREFIX: ${VECTOR_INDEX_NAME_PREFIX:-Vector_index}
WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080}
WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}
QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333}
@@ -332,6 +337,8 @@ x-shared-env: &shared-api-worker-env
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
+ CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: ${CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES:-5}
+ OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: ${OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES:-5}
CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194}
CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox}
CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807}
@@ -353,6 +360,10 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
+ CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository}
+ CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
+ API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
+ API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@@ -364,6 +375,7 @@ x-shared-env: &shared-api-worker-env
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
+ ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
@@ -420,7 +432,7 @@ x-shared-env: &shared-api-worker-env
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
- NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M}
+ NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
@@ -498,6 +510,8 @@ x-shared-env: &shared-api-worker-env
PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ENABLE_OTEL: ${ENABLE_OTEL:-false}
+ OTLP_TRACE_ENDPOINT: ${OTLP_TRACE_ENDPOINT:-}
+ OTLP_METRIC_ENDPOINT: ${OTLP_METRIC_ENDPOINT:-}
OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318}
OTLP_API_KEY: ${OTLP_API_KEY:-}
OTEL_EXPORTER_OTLP_PROTOCOL: ${OTEL_EXPORTER_OTLP_PROTOCOL:-}
@@ -513,11 +527,19 @@ x-shared-env: &shared-api-worker-env
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
+ ENABLE_CLEAN_EMBEDDING_CACHE_TASK: ${ENABLE_CLEAN_EMBEDDING_CACHE_TASK:-false}
+ ENABLE_CLEAN_UNUSED_DATASETS_TASK: ${ENABLE_CLEAN_UNUSED_DATASETS_TASK:-false}
+ ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false}
+ ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-false}
+ ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-false}
+ ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false}
+ ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false}
+ ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true}
services:
# API service
api:
- image: langgenius/dify-api:1.5.1
+ image: langgenius/dify-api:1.7.0
restart: always
environment:
# Use the shared environment variables.
@@ -546,7 +568,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
- image: langgenius/dify-api:1.5.1
+ image: langgenius/dify-api:1.7.0
restart: always
environment:
# Use the shared environment variables.
@@ -570,9 +592,28 @@ services:
- ssrf_proxy_network
- default
+ # worker_beat service
+ # Celery beat for scheduling periodic tasks.
+ worker_beat:
+ image: langgenius/dify-api:1.7.0
+ restart: always
+ environment:
+ # Use the shared environment variables.
+ <<: *shared-api-worker-env
+ # Startup mode, 'beat' starts the Celery beat for scheduling periodic tasks.
+ MODE: beat
+ depends_on:
+ db:
+ condition: service_healthy
+ redis:
+ condition: service_started
+ networks:
+ - ssrf_proxy_network
+ - default
+
# Frontend web application.
web:
- image: langgenius/dify-web:1.5.1
+ image: langgenius/dify-web:1.7.0
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -582,6 +623,7 @@ services:
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
CSP_WHITELIST: ${CSP_WHITELIST:-}
ALLOW_EMBED: ${ALLOW_EMBED:-false}
+ ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
@@ -657,7 +699,7 @@ services:
# plugin daemon
plugin_daemon:
- image: langgenius/dify-plugin-daemon:0.1.3-local
+ image: langgenius/dify-plugin-daemon:0.2.0-local
restart: always
environment:
# Use the shared environment variables.
@@ -780,7 +822,7 @@ services:
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
- NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M}
+ NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
diff --git a/docker/nginx/conf.d/default.conf.template b/docker/nginx/conf.d/default.conf.template
index a458412d1e..48d7da8cf5 100644
--- a/docker/nginx/conf.d/default.conf.template
+++ b/docker/nginx/conf.d/default.conf.template
@@ -39,7 +39,10 @@ server {
proxy_pass http://web:3000;
include proxy.conf;
}
-
+ location /mcp {
+ proxy_pass http://api:5001;
+ include proxy.conf;
+ }
# placeholder for acme challenge location
${ACME_CHALLENGE_LOCATION}
diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md
index 8949ef08fa..7401fd2fd4 100644
--- a/sdks/python-client/README.md
+++ b/sdks/python-client/README.md
@@ -183,3 +183,42 @@ rename_conversation_response.raise_for_status()
print('[rename result]')
print(rename_conversation_response.json())
```
+
+* Using the Workflow Client
+```python
+from dify_client import WorkflowClient
+
+api_key = "your_api_key"
+
+# Initialize the Workflow Client
+client = WorkflowClient(api_key)
+
+# Prepare the inputs expected by your workflow
+user_id = "your_user_id"
+inputs = {
+    "context": "previous user interaction / metadata",
+    "user_prompt": "What is the capital of France?",
+    # Add any other input fields your workflow defines (e.g., additional context, task parameters)
+}
+
+# Set the response mode (default: streaming)
+response_mode = "blocking"
+
+# Run the workflow
+response = client.run(inputs=inputs, response_mode=response_mode, user=user_id)
+response.raise_for_status()
+
+# Parse the result; in blocking mode the full payload is returned at once
+result = response.json()
+outputs = result["data"]["outputs"]
+
+# Assumes your workflow defines an output variable named "answer"
+print(outputs["answer"])
+```
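+
+* Consuming a streaming workflow run
+
+In streaming mode (the default), the run is delivered as server-sent events. The snippet below is a minimal sketch of one way to consume them — it assumes the client hands back the underlying `requests` response and that events follow the public workflow API (`text_chunk`, `workflow_finished`, ...); depending on how the request is issued, events may be buffered rather than arriving incrementally:
+```python
+import json
+
+from dify_client import WorkflowClient
+
+client = WorkflowClient("your_api_key")
+
+response = client.run(
+    inputs={"user_prompt": "What is the capital of France?"},
+    response_mode="streaming",
+    user="your_user_id",
+)
+response.raise_for_status()
+
+# Each SSE payload line looks like: data: {"event": "...", "data": {...}}
+for line in response.iter_lines():
+    if not line.startswith(b"data:"):
+        continue  # skip keep-alives and blank lines
+    event = json.loads(line[len(b"data:"):])
+    if event.get("event") == "text_chunk":
+        print(event["data"]["text"], end="", flush=True)
+    elif event.get("event") == "workflow_finished":
+        print()
+        print(event["data"]["outputs"])
+```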
diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py
index 6fa9d190e5..d00c207afa 100644
--- a/sdks/python-client/dify_client/__init__.py
+++ b/sdks/python-client/dify_client/__init__.py
@@ -1 +1,7 @@
-from dify_client.client import ChatClient, CompletionClient, DifyClient
+from dify_client.client import (
+ ChatClient,
+ CompletionClient,
+ WorkflowClient,
+ KnowledgeBaseClient,
+ DifyClient,
+)
diff --git a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py
deleted file mode 100644
index 47c175acd7..0000000000
--- a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py
+++ /dev/null
@@ -1,248 +0,0 @@
-import threading
-from unittest.mock import Mock, patch
-
-from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
-from core.entities.provider_entities import QuotaUnit
-from events.event_handlers.update_provider_when_message_created import (
- handle,
- get_update_stats,
-)
-from models.provider import ProviderType
-from sqlalchemy.exc import OperationalError
-
-
-class TestProviderUpdateDeadlockPrevention:
- """Test suite for deadlock prevention in Provider updates."""
-
- def setup_method(self):
- """Setup test fixtures."""
- self.mock_message = Mock()
- self.mock_message.answer_tokens = 100
-
- self.mock_app_config = Mock()
- self.mock_app_config.tenant_id = "test-tenant-123"
-
- self.mock_model_conf = Mock()
- self.mock_model_conf.provider = "openai"
-
- self.mock_system_config = Mock()
- self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
-
- self.mock_provider_config = Mock()
- self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
- self.mock_provider_config.system_configuration = self.mock_system_config
-
- self.mock_provider_bundle = Mock()
- self.mock_provider_bundle.configuration = self.mock_provider_config
-
- self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
-
- self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
- self.mock_generate_entity.app_config = self.mock_app_config
- self.mock_generate_entity.model_conf = self.mock_model_conf
-
- @patch("events.event_handlers.update_provider_when_message_created.db")
- def test_consolidated_handler_basic_functionality(self, mock_db):
- """Test that the consolidated handler performs both updates correctly."""
- # Setup mock query chain
- mock_query = Mock()
- mock_db.session.query.return_value = mock_query
- mock_query.filter.return_value = mock_query
- mock_query.order_by.return_value = mock_query
- mock_query.update.return_value = 1 # 1 row affected
-
- # Call the handler
- handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
-
- # Verify db.session.query was called
- assert mock_db.session.query.called
-
- # Verify commit was called
- mock_db.session.commit.assert_called_once()
-
- # Verify no rollback was called
- assert not mock_db.session.rollback.called
-
- @patch("events.event_handlers.update_provider_when_message_created.db")
- def test_deadlock_retry_mechanism(self, mock_db):
- """Test that deadlock errors trigger retry logic."""
- # Setup mock to raise deadlock error on first attempt, succeed on second
- mock_query = Mock()
- mock_db.session.query.return_value = mock_query
- mock_query.filter.return_value = mock_query
- mock_query.order_by.return_value = mock_query
- mock_query.update.return_value = 1
-
- # First call raises deadlock, second succeeds
- mock_db.session.commit.side_effect = [
- OperationalError("deadlock detected", None, None),
- None, # Success on retry
- ]
-
- # Call the handler
- handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
-
- # Verify commit was called twice (original + retry)
- assert mock_db.session.commit.call_count == 2
-
- # Verify rollback was called once (after first failure)
- mock_db.session.rollback.assert_called_once()
-
- @patch("events.event_handlers.update_provider_when_message_created.db")
- @patch("events.event_handlers.update_provider_when_message_created.time.sleep")
- def test_exponential_backoff_timing(self, mock_sleep, mock_db):
- """Test that retry delays follow exponential backoff pattern."""
- # Setup mock to fail twice, succeed on third attempt
- mock_query = Mock()
- mock_db.session.query.return_value = mock_query
- mock_query.filter.return_value = mock_query
- mock_query.order_by.return_value = mock_query
- mock_query.update.return_value = 1
-
- mock_db.session.commit.side_effect = [
- OperationalError("deadlock detected", None, None),
- OperationalError("deadlock detected", None, None),
- None, # Success on third attempt
- ]
-
- # Call the handler
- handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
-
- # Verify sleep was called twice with increasing delays
- assert mock_sleep.call_count == 2
-
- # First delay should be around 0.1s + jitter
- first_delay = mock_sleep.call_args_list[0][0][0]
- assert 0.1 <= first_delay <= 0.3
-
- # Second delay should be around 0.2s + jitter
- second_delay = mock_sleep.call_args_list[1][0][0]
- assert 0.2 <= second_delay <= 0.4
-
- def test_concurrent_handler_execution(self):
- """Test that multiple handlers can run concurrently without deadlock."""
- results = []
- errors = []
-
- def run_handler():
- try:
- with patch(
- "events.event_handlers.update_provider_when_message_created.db"
- ) as mock_db:
- mock_query = Mock()
- mock_db.session.query.return_value = mock_query
- mock_query.filter.return_value = mock_query
- mock_query.order_by.return_value = mock_query
- mock_query.update.return_value = 1
-
- handle(
- self.mock_message,
- application_generate_entity=self.mock_generate_entity,
- )
- results.append("success")
- except Exception as e:
- errors.append(str(e))
-
- # Run multiple handlers concurrently
- threads = []
- for _ in range(5):
- thread = threading.Thread(target=run_handler)
- threads.append(thread)
- thread.start()
-
- # Wait for all threads to complete
- for thread in threads:
- thread.join(timeout=5)
-
- # Verify all handlers completed successfully
- assert len(results) == 5
- assert len(errors) == 0
-
- def test_performance_stats_tracking(self):
- """Test that performance statistics are tracked correctly."""
- # Reset stats
- stats = get_update_stats()
- initial_total = stats["total_updates"]
-
- with patch(
- "events.event_handlers.update_provider_when_message_created.db"
- ) as mock_db:
- mock_query = Mock()
- mock_db.session.query.return_value = mock_query
- mock_query.filter.return_value = mock_query
- mock_query.order_by.return_value = mock_query
- mock_query.update.return_value = 1
-
- # Call handler
- handle(
- self.mock_message, application_generate_entity=self.mock_generate_entity
- )
-
- # Check that stats were updated
- updated_stats = get_update_stats()
- assert updated_stats["total_updates"] == initial_total + 1
- assert updated_stats["successful_updates"] >= initial_total + 1
-
- def test_non_chat_entity_ignored(self):
- """Test that non-chat entities are ignored by the handler."""
- # Create a non-chat entity
- mock_non_chat_entity = Mock()
- mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
-
- with patch(
- "events.event_handlers.update_provider_when_message_created.db"
- ) as mock_db:
- # Call handler with non-chat entity
- handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
-
- # Verify no database operations were performed
- assert not mock_db.session.query.called
- assert not mock_db.session.commit.called
-
- @patch("events.event_handlers.update_provider_when_message_created.db")
- def test_quota_calculation_tokens(self, mock_db):
- """Test quota calculation for token-based quotas."""
- # Setup token-based quota
- self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
- self.mock_message.answer_tokens = 150
-
- mock_query = Mock()
- mock_db.session.query.return_value = mock_query
- mock_query.filter.return_value = mock_query
- mock_query.order_by.return_value = mock_query
- mock_query.update.return_value = 1
-
- # Call handler
- handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
-
- # Verify update was called with token count
- update_calls = mock_query.update.call_args_list
-
- # Should have at least one call with quota_used update
- quota_update_found = False
- for call in update_calls:
- values = call[0][0] # First argument to update()
- if "quota_used" in values:
- quota_update_found = True
- break
-
- assert quota_update_found
-
- @patch("events.event_handlers.update_provider_when_message_created.db")
- def test_quota_calculation_times(self, mock_db):
- """Test quota calculation for times-based quotas."""
- # Setup times-based quota
- self.mock_system_config.current_quota_type = QuotaUnit.TIMES
-
- mock_query = Mock()
- mock_db.session.query.return_value = mock_query
- mock_query.filter.return_value = mock_query
- mock_query.order_by.return_value = mock_query
- mock_query.update.return_value = 1
-
- # Call handler
- handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
-
- # Verify update was called
- assert mock_query.update.called
- assert mock_db.session.commit.called
diff --git a/web/.env.example b/web/.env.example
index c30064ffed..37bfc939eb 100644
--- a/web/.env.example
+++ b/web/.env.example
@@ -32,6 +32,9 @@ NEXT_PUBLIC_CSP_WHITELIST=
# By default, embedding in an iframe is not allowed, to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking
NEXT_PUBLIC_ALLOW_EMBED=
+# Allow rendering unsafe URLs that use the "data:" scheme.
+NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false
+
# Github Access Token, used for invoking Github API
NEXT_PUBLIC_GITHUB_ACCESS_TOKEN=
# The maximum number of top-k value for RAG.
diff --git a/web/Dockerfile b/web/Dockerfile
index 93eef59815..d59039528c 100644
--- a/web/Dockerfile
+++ b/web/Dockerfile
@@ -6,7 +6,7 @@ LABEL maintainer="takatost@gmail.com"
# RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories
RUN apk add --no-cache tzdata
-RUN npm install -g pnpm@10.11.1
+RUN npm install -g pnpm@10.13.1
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx
index 084adceef2..3d572b926a 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx
@@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import AppCard from '@/app/components/app/overview/appCard'
import Loading from '@/app/components/base/loading'
+import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card'
import { ToastContext } from '@/app/components/base/toast'
import {
fetchAppDetail,
@@ -31,6 +32,8 @@ const CardView: FC = ({ appId, isInPanel, className }) => {
const appDetail = useAppStore(state => state.appDetail)
const setAppDetail = useAppStore(state => state.setAppDetail)
+ const showMCPCard = isInPanel
+
const updateAppDetail = async () => {
try {
const res = await fetchAppDetail({ url: '/apps', id: appId })
@@ -117,6 +120,11 @@ const CardView: FC = ({ appId, isInPanel, className }) => {
isInPanel={isInPanel}
onChangeStatus={onChangeApiStatus}
/>
+ {showMCPCard && (
+   <MCPServiceCard appInfo={appDetail} />
+ )}
)
}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx
index 92ba068b2b..2afe451fe1 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx
@@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import TracingIcon from './tracing-icon'
import ProviderPanel from './provider-panel'
-import type { ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
+import type { AliyunConfig, ArizeConfig, LangFuseConfig, LangSmithConfig, OpikConfig, PhoenixConfig, WeaveConfig } from './type'
import { TracingProvider } from './type'
import ProviderConfigModal from './provider-config-modal'
import Indicator from '@/app/components/header/indicator'
@@ -29,7 +29,8 @@ export type PopupProps = {
langFuseConfig: LangFuseConfig | null
opikConfig: OpikConfig | null
weaveConfig: WeaveConfig | null
- onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig) => void
+ aliyunConfig: AliyunConfig | null
+ onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig) => void
onConfigRemoved: (provider: TracingProvider) => void
}
@@ -46,6 +47,7 @@ const ConfigPopup: FC