mirror of https://github.com/langgenius/dify.git
Merge f15464645c into 2c919efa69
This commit is contained in:
commit
fedd22af6a
|
|
@ -1,4 +1,6 @@
|
|||
import base64
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, overload
|
||||
|
||||
from libs import rsa
|
||||
|
||||
|
|
@ -42,3 +44,60 @@ def get_decrypt_decoding(tenant_id: str):
|
|||
|
||||
def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa):
|
||||
return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa)
|
||||
|
||||
|
||||
# =========================
|
||||
# encrypt_secret_keys
|
||||
# =========================
|
||||
|
||||
|
||||
# Overloads to preserve input type
|
||||
@overload
|
||||
def encrypt_secret_keys(
|
||||
obj: Mapping[str, Any],
|
||||
secret_variables: set[str] | None = None,
|
||||
parent_key: str | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def encrypt_secret_keys(
|
||||
obj: list[Any],
|
||||
secret_variables: set[str] | None = None,
|
||||
parent_key: str | None = None,
|
||||
) -> list[Any]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def encrypt_secret_keys(
|
||||
obj: Any,
|
||||
secret_variables: set[str] | None = None,
|
||||
parent_key: str | None = None,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
def encrypt_secret_keys(
|
||||
obj: Any,
|
||||
secret_variables: set[str] | None = None,
|
||||
parent_key: str | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Recursively obfuscate the value if it belongs to a Secret Variable.
|
||||
Preserves input type: dict -> dict, list -> list, scalar -> scalar.
|
||||
"""
|
||||
if secret_variables is None:
|
||||
secret_variables = set()
|
||||
|
||||
if isinstance(obj, Mapping):
|
||||
# recurse into dict
|
||||
return {key: encrypt_secret_keys(value, secret_variables, key) for key, value in obj.items()}
|
||||
|
||||
elif isinstance(obj, list):
|
||||
# recurse into all list elements
|
||||
return [encrypt_secret_keys(value, secret_variables, None) for value in obj]
|
||||
|
||||
else:
|
||||
# leaf node: obfuscate if parent_key is a secret variable
|
||||
if parent_key in secret_variables:
|
||||
return obfuscated_token(str(obj))
|
||||
return obj
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
|||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.helper import encrypter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
|
|
@ -24,6 +25,7 @@ from core.tools.entities.tool_entities import (
|
|||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables import SecretVariable
|
||||
from core.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
|
|
@ -115,6 +117,19 @@ class AgentNode(Node[AgentNodeData]):
|
|||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
# to store secret variables used in the Agent block.
|
||||
secret_variables = set()
|
||||
# get secret variables used.
|
||||
for section_vars in self.graph_runtime_state.variable_pool.variable_dictionary.values():
|
||||
# Iterate over all the sections. e.g. sys, env etc.
|
||||
if isinstance(section_vars, dict):
|
||||
# Iterate over each variable in the section
|
||||
for variable in section_vars.values():
|
||||
# Check if the variable is a SecretVariable
|
||||
if isinstance(variable, SecretVariable):
|
||||
# Add the variable name to the set
|
||||
secret_variables.add(variable.name)
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
|
|
@ -147,6 +162,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
secret_variables=secret_variables,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
|
|
@ -467,6 +483,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
secret_variables: set[str] | None = None,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
|
|
@ -650,9 +667,9 @@ class AgentNode(Node[AgentNodeData]):
|
|||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
data=encrypter.encrypt_secret_keys(message.message.data, secret_variables, None),
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
metadata=encrypter.encrypt_secret_keys(message.message.metadata, secret_variables, None),
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ from decimal import Decimal
|
|||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables import SecretVariable
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
|
|
@ -49,15 +51,24 @@ class CodeNode(Node[CodeNodeData]):
|
|||
code_language = self.node_data.code_language
|
||||
code = self.node_data.code
|
||||
|
||||
# to store secret variables used in the code block.
|
||||
secret_variables = set()
|
||||
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in self.node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if isinstance(variable, SecretVariable):
|
||||
secret_variables.add(variable_name)
|
||||
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None
|
||||
else:
|
||||
variables[variable_name] = variable.to_object() if variable else None
|
||||
|
||||
obfuscated_variables = encrypter.encrypt_secret_keys(variables, secret_variables, None)
|
||||
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
|
|
@ -70,10 +81,13 @@ class CodeNode(Node[CodeNodeData]):
|
|||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=obfuscated_variables,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=obfuscated_variables, outputs=result)
|
||||
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import pytest
|
|||
from core.helper.encrypter import (
|
||||
batch_decrypt_token,
|
||||
decrypt_token,
|
||||
encrypt_secret_keys,
|
||||
encrypt_token,
|
||||
get_decrypt_decoding,
|
||||
obfuscated_token,
|
||||
|
|
@ -36,6 +37,47 @@ class TestObfuscatedToken:
|
|||
assert token not in obfuscated
|
||||
assert "*" * 12 in obfuscated
|
||||
|
||||
def test_encrypt_secret_keys_simple_dict(self):
|
||||
data = {"api_key": "fake-secret-key", "username": "admin"}
|
||||
secret_vars = {"api_key"}
|
||||
|
||||
result = encrypt_secret_keys(data, secret_vars)
|
||||
|
||||
# api_key should be obfuscated
|
||||
assert result["api_key"] == obfuscated_token("fake-secret-key")
|
||||
# username should remain unchanged
|
||||
assert result["username"] == "admin"
|
||||
|
||||
def test_encrypt_secret_keys_nested_dict(self):
|
||||
data = {"outer": {"inner_secret": "super-secret", "inner_public": "visible"}, "non_secret": "plain"}
|
||||
secret_vars = {"inner_secret"}
|
||||
|
||||
result = encrypt_secret_keys(data, secret_vars)
|
||||
|
||||
assert result["outer"]["inner_secret"] == obfuscated_token("super-secret")
|
||||
assert result["outer"]["inner_public"] == "visible"
|
||||
assert result["non_secret"] == "plain"
|
||||
|
||||
def test_encrypt_secret_keys_list_of_dicts(self):
|
||||
data = [{"token1": "abc123", "id": 1}, {"token2": "xyz789", "id": 2}]
|
||||
secret_vars = {"token1", "token2"}
|
||||
|
||||
result = encrypt_secret_keys(data, secret_vars)
|
||||
|
||||
assert result[0]["token1"] == obfuscated_token("abc123")
|
||||
assert result[1]["token2"] == obfuscated_token("xyz789")
|
||||
assert result[0]["id"] == 1
|
||||
|
||||
def test_encrypt_secret_keys_non_secret_scalar(self):
|
||||
# When the object is just a string, it should remain unchanged
|
||||
result = encrypt_secret_keys("hello-world", secret_variables={"api_key"})
|
||||
assert result == "hello-world"
|
||||
|
||||
def test_encrypt_secret_keys_handles_empty_inputs(self):
|
||||
assert encrypt_secret_keys({}, {"secret"}) == {}
|
||||
assert encrypt_secret_keys([], {"secret"}) == []
|
||||
assert encrypt_secret_keys(None, {"secret"}) is None
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
|
|
|
|||
Loading…
Reference in New Issue