mirror of https://github.com/langgenius/dify.git
refactor: port api/controllers/console/datasets/data_source.py /datasets/metadata.py /service_api/dataset/metadata.py /nodes/agent/agent_node.py api/core/workflow/nodes/datasource/datasource_node.py api/services/dataset_service.py to match/case (#31836)
This commit is contained in:
parent ce2c41bbf5
commit 491fa9923b
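Note: every hunk in this commit applies the same mechanical transformation: a chained if/elif dispatch on a single value becomes a Python 3.10+ match statement. A minimal sketch of the pattern (illustrative only; do_enable/do_disable are placeholder helpers, not code from this repository):

    def do_enable() -> None:   # placeholder helper for illustration
        print("enabled")

    def do_disable() -> None:  # placeholder helper for illustration
        print("disabled")

    action = "enable"

    # before: chained equality tests on one subject
    if action == "enable":
        do_enable()
    elif action == "disable":
        do_disable()

    # after: structural pattern matching on the same subject
    match action:
        case "enable":
            do_enable()
        case "disable":
            do_disable()

Behavior is unchanged: cases are tried top to bottom, at most one body runs, and there is no C-style fall-through.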
api/controllers/console/datasets/data_source.py
@@ -1,6 +1,6 @@
 import json
 from collections.abc import Generator
-from typing import Any, cast
+from typing import Any, Literal, cast

 from flask import request
 from flask_restx import Resource, fields, marshal_with
@@ -157,9 +157,8 @@ class DataSourceApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def patch(self, binding_id, action):
+    def patch(self, binding_id, action: Literal["enable", "disable"]):
         binding_id = str(binding_id)
-        action = str(action)
         with Session(db.engine) as session:
             data_source_binding = session.execute(
                 select(DataSourceOauthBinding).filter_by(id=binding_id)
@@ -167,23 +166,24 @@ class DataSourceApi(Resource):
         if data_source_binding is None:
             raise NotFound("Data source binding not found.")
-        # enable binding
-        if action == "enable":
-            if data_source_binding.disabled:
-                data_source_binding.disabled = False
-                data_source_binding.updated_at = naive_utc_now()
-                db.session.add(data_source_binding)
-                db.session.commit()
-            else:
-                raise ValueError("Data source is not disabled.")
-        # disable binding
-        if action == "disable":
-            if not data_source_binding.disabled:
-                data_source_binding.disabled = True
-                data_source_binding.updated_at = naive_utc_now()
-                db.session.add(data_source_binding)
-                db.session.commit()
-            else:
-                raise ValueError("Data source is disabled.")
+        match action:
+            case "enable":
+                if data_source_binding.disabled:
+                    data_source_binding.disabled = False
+                    data_source_binding.updated_at = naive_utc_now()
+                    db.session.add(data_source_binding)
+                    db.session.commit()
+                else:
+                    raise ValueError("Data source is not disabled.")
+            # disable binding
+            case "disable":
+                if not data_source_binding.disabled:
+                    data_source_binding.disabled = True
+                    data_source_binding.updated_at = naive_utc_now()
+                    db.session.add(data_source_binding)
+                    db.session.commit()
+                else:
+                    raise ValueError("Data source is disabled.")
         return {"result": "success"}, 200

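Note: the new signature types action as Literal["enable", "disable"], presumably because the route already supplies a string, which is why the redundant action = str(action) line could be dropped; the annotation documents the only accepted values for static checkers and is not enforced at runtime. A common companion pattern, not used in this commit, is exhaustiveness checking with typing.assert_never (Python 3.11+), sketched here with hypothetical names:

    from typing import Literal, assert_never

    def toggle(action: Literal["enable", "disable"]) -> bool:
        match action:
            case "enable":
                return True
            case "disable":
                return False
            case _:
                # Unreachable when every Literal member is handled above;
                # a type checker reports an error here if a new member is
                # added to the Literal without a corresponding case.
                assert_never(action)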
api/controllers/console/datasets/metadata.py
@@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
             raise NotFound("Dataset not found.")
         DatasetService.check_dataset_permission(dataset, current_user)

-        if action == "enable":
-            MetadataService.enable_built_in_field(dataset)
-        elif action == "disable":
-            MetadataService.disable_built_in_field(dataset)
+        match action:
+            case "enable":
+                MetadataService.enable_built_in_field(dataset)
+            case "disable":
+                MetadataService.disable_built_in_field(dataset)
         return {"result": "success"}, 200

api/controllers/service_api/dataset/metadata.py
@@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
             raise NotFound("Dataset not found.")
         DatasetService.check_dataset_permission(dataset, current_user)

-        if action == "enable":
-            MetadataService.enable_built_in_field(dataset)
-        elif action == "disable":
-            MetadataService.disable_built_in_field(dataset)
+        match action:
+            case "enable":
+                MetadataService.enable_built_in_field(dataset)
+            case "disable":
+                MetadataService.disable_built_in_field(dataset)
         return {"result": "success"}, 200

api/core/workflow/nodes/agent/agent_node.py
@@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]):
                 result[parameter_name] = None
                 continue
             agent_input = node_data.agent_parameters[parameter_name]
-            if agent_input.type == "variable":
-                variable = variable_pool.get(agent_input.value)  # type: ignore
-                if variable is None:
-                    raise AgentVariableNotFoundError(str(agent_input.value))
-                parameter_value = variable.value
-            elif agent_input.type in {"mixed", "constant"}:
-                # variable_pool.convert_template expects a string template,
-                # but if passing a dict, convert to JSON string first before rendering
-                try:
-                    if not isinstance(agent_input.value, str):
-                        parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
-                    else:
-                        parameter_value = str(agent_input.value)
-                except TypeError:
-                    parameter_value = str(agent_input.value)
-                segment_group = variable_pool.convert_template(parameter_value)
-                parameter_value = segment_group.log if for_log else segment_group.text
-                # variable_pool.convert_template returns a string,
-                # so we need to convert it back to a dictionary
-                try:
-                    if not isinstance(agent_input.value, str):
-                        parameter_value = json.loads(parameter_value)
-                except json.JSONDecodeError:
-                    parameter_value = parameter_value
-            else:
-                raise AgentInputTypeError(agent_input.type)
+            match agent_input.type:
+                case "variable":
+                    variable = variable_pool.get(agent_input.value)  # type: ignore
+                    if variable is None:
+                        raise AgentVariableNotFoundError(str(agent_input.value))
+                    parameter_value = variable.value
+                case "mixed" | "constant":
+                    # variable_pool.convert_template expects a string template,
+                    # but if passing a dict, convert to JSON string first before rendering
+                    try:
+                        if not isinstance(agent_input.value, str):
+                            parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
+                        else:
+                            parameter_value = str(agent_input.value)
+                    except TypeError:
+                        parameter_value = str(agent_input.value)
+                    segment_group = variable_pool.convert_template(parameter_value)
+                    parameter_value = segment_group.log if for_log else segment_group.text
+                    # variable_pool.convert_template returns a string,
+                    # so we need to convert it back to a dictionary
+                    try:
+                        if not isinstance(agent_input.value, str):
+                            parameter_value = json.loads(parameter_value)
+                    except json.JSONDecodeError:
+                        parameter_value = parameter_value
+                case _:
+                    raise AgentInputTypeError(agent_input.type)
             value = parameter_value
             if parameter.type == "array[tools]":
                 value = cast(list[dict[str, Any]], value)
@@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]):
         result: dict[str, Any] = {}
         for parameter_name in typed_node_data.agent_parameters:
             input = typed_node_data.agent_parameters[parameter_name]
-            if input.type in ["mixed", "constant"]:
-                selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
-                for selector in selectors:
-                    result[selector.variable] = selector.value_selector
-            elif input.type == "variable":
-                result[parameter_name] = input.value
+            match input.type:
+                case "mixed" | "constant":
+                    selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
+                    for selector in selectors:
+                        result[selector.variable] = selector.value_selector
+                case "variable":
+                    result[parameter_name] = input.value

         result = {node_id + "." + key: value for key, value in result.items()}
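Note: case "mixed" | "constant": in the AgentNode hunks is an or-pattern, the match equivalent of the old membership tests in {"mixed", "constant"} and in ["mixed", "constant"]: the alternatives are tried left to right and share one case body. A minimal sketch with hypothetical names:

    def classify(kind: str) -> str:
        match kind:
            case "mixed" | "constant":  # or-pattern: either literal matches
                return "template input"
            case "variable":
                return "variable selector"
            case _:
                return "unknown"

    assert classify("constant") == "template input"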
api/core/workflow/nodes/datasource/datasource_node.py
@@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]):
         if typed_node_data.datasource_parameters:
             for parameter_name in typed_node_data.datasource_parameters:
                 input = typed_node_data.datasource_parameters[parameter_name]
-                if input.type == "mixed":
-                    assert isinstance(input.value, str)
-                    selectors = VariableTemplateParser(input.value).extract_variable_selectors()
-                    for selector in selectors:
-                        result[selector.variable] = selector.value_selector
-                elif input.type == "variable":
-                    result[parameter_name] = input.value
-                elif input.type == "constant":
-                    pass
+                match input.type:
+                    case "mixed":
+                        assert isinstance(input.value, str)
+                        selectors = VariableTemplateParser(input.value).extract_variable_selectors()
+                        for selector in selectors:
+                            result[selector.variable] = selector.value_selector
+                    case "variable":
+                        result[parameter_name] = input.value
+                    case "constant":
+                        pass
+                    case None:
+                        pass

         result = {node_id + "." + key: value for key, value in result.items()}
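Note: the hunk above adds an explicit case None: pass arm that the old elif chain handled only implicitly (None matched no branch). In a match statement, None, True, and False are matched by identity, and when no case matches and there is no case _ wildcard, execution simply continues after the block without raising. A small sketch of both facts, with a hypothetical value:

    value = None
    match value:
        case "constant":
            pass  # explicitly ignored
        case None:
            pass  # singleton pattern, compared with 'is None'
    # With no 'case _', any other value would match nothing and control
    # would fall out of the match block silently.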
@@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]):
         variables: dict[str, Any] = {}

         for message in message_stream:
-            if message.type in {
-                DatasourceMessage.MessageType.IMAGE_LINK,
-                DatasourceMessage.MessageType.BINARY_LINK,
-                DatasourceMessage.MessageType.IMAGE,
-            }:
-                assert isinstance(message.message, DatasourceMessage.TextMessage)
-
-                url = message.message.text
-                transfer_method = FileTransferMethod.TOOL_FILE
-
-                datasource_file_id = str(url).split("/")[-1].split(".")[0]
-
-                with Session(db.engine) as session:
-                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
-                    datasource_file = session.scalar(stmt)
-                    if datasource_file is None:
-                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
-
-                    mapping = {
-                        "tool_file_id": datasource_file_id,
-                        "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
-                        "transfer_method": transfer_method,
-                        "url": url,
-                    }
-                    file = file_factory.build_from_mapping(
-                        mapping=mapping,
-                        tenant_id=self.tenant_id,
-                    )
-                    files.append(file)
-            elif message.type == DatasourceMessage.MessageType.BLOB:
-                # get tool file id
-                assert isinstance(message.message, DatasourceMessage.TextMessage)
-                assert message.meta
-
-                datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
-                with Session(db.engine) as session:
-                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
-                    datasource_file = session.scalar(stmt)
-                    if datasource_file is None:
-                        raise ToolFileError(f"datasource file {datasource_file_id} not exists")
-
-                    mapping = {
-                        "tool_file_id": datasource_file_id,
-                        "transfer_method": FileTransferMethod.TOOL_FILE,
-                    }
-
-                    files.append(
-                        file_factory.build_from_mapping(
-                            mapping=mapping,
-                            tenant_id=self.tenant_id,
-                        )
-                    )
-            elif message.type == DatasourceMessage.MessageType.TEXT:
-                assert isinstance(message.message, DatasourceMessage.TextMessage)
-                text += message.message.text
-                yield StreamChunkEvent(
-                    selector=[self._node_id, "text"],
-                    chunk=message.message.text,
-                    is_final=False,
-                )
-            elif message.type == DatasourceMessage.MessageType.JSON:
-                assert isinstance(message.message, DatasourceMessage.JsonMessage)
-                json.append(message.message.json_object)
-            elif message.type == DatasourceMessage.MessageType.LINK:
-                assert isinstance(message.message, DatasourceMessage.TextMessage)
-                stream_text = f"Link: {message.message.text}\n"
-                text += stream_text
-                yield StreamChunkEvent(
-                    selector=[self._node_id, "text"],
-                    chunk=stream_text,
-                    is_final=False,
-                )
-            elif message.type == DatasourceMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, DatasourceMessage.VariableMessage)
-                variable_name = message.message.variable_name
-                variable_value = message.message.variable_value
-                if message.message.stream:
-                    if not isinstance(variable_value, str):
-                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
-                    if variable_name not in variables:
-                        variables[variable_name] = ""
-                    variables[variable_name] += variable_value
-                    yield StreamChunkEvent(
-                        selector=[self._node_id, variable_name],
-                        chunk=variable_value,
-                        is_final=False,
-                    )
-                else:
-                    variables[variable_name] = variable_value
-            elif message.type == DatasourceMessage.MessageType.FILE:
-                assert message.meta is not None
-                files.append(message.meta["file"])
+            match message.type:
+                case (
+                    DatasourceMessage.MessageType.IMAGE_LINK
+                    | DatasourceMessage.MessageType.BINARY_LINK
+                    | DatasourceMessage.MessageType.IMAGE
+                ):
+                    assert isinstance(message.message, DatasourceMessage.TextMessage)
+
+                    url = message.message.text
+                    transfer_method = FileTransferMethod.TOOL_FILE
+
+                    datasource_file_id = str(url).split("/")[-1].split(".")[0]
+
+                    with Session(db.engine) as session:
+                        stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
+                        datasource_file = session.scalar(stmt)
+                        if datasource_file is None:
+                            raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
+
+                        mapping = {
+                            "tool_file_id": datasource_file_id,
+                            "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
+                            "transfer_method": transfer_method,
+                            "url": url,
+                        }
+                        file = file_factory.build_from_mapping(
+                            mapping=mapping,
+                            tenant_id=self.tenant_id,
+                        )
+                        files.append(file)
+                case DatasourceMessage.MessageType.BLOB:
+                    # get tool file id
+                    assert isinstance(message.message, DatasourceMessage.TextMessage)
+                    assert message.meta
+
+                    datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
+                    with Session(db.engine) as session:
+                        stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
+                        datasource_file = session.scalar(stmt)
+                        if datasource_file is None:
+                            raise ToolFileError(f"datasource file {datasource_file_id} not exists")
+
+                        mapping = {
+                            "tool_file_id": datasource_file_id,
+                            "transfer_method": FileTransferMethod.TOOL_FILE,
+                        }

+                        files.append(
+                            file_factory.build_from_mapping(
+                                mapping=mapping,
+                                tenant_id=self.tenant_id,
+                            )
+                        )
+                case DatasourceMessage.MessageType.TEXT:
+                    assert isinstance(message.message, DatasourceMessage.TextMessage)
+                    text += message.message.text
+                    yield StreamChunkEvent(
+                        selector=[self._node_id, "text"],
+                        chunk=message.message.text,
+                        is_final=False,
+                    )
+                case DatasourceMessage.MessageType.JSON:
+                    assert isinstance(message.message, DatasourceMessage.JsonMessage)
+                    json.append(message.message.json_object)
+                case DatasourceMessage.MessageType.LINK:
+                    assert isinstance(message.message, DatasourceMessage.TextMessage)
+                    stream_text = f"Link: {message.message.text}\n"
+                    text += stream_text
+                    yield StreamChunkEvent(
+                        selector=[self._node_id, "text"],
+                        chunk=stream_text,
+                        is_final=False,
+                    )
+                case DatasourceMessage.MessageType.VARIABLE:
+                    assert isinstance(message.message, DatasourceMessage.VariableMessage)
+                    variable_name = message.message.variable_name
+                    variable_value = message.message.variable_value
+                    if message.message.stream:
+                        if not isinstance(variable_value, str):
+                            raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+                        if variable_name not in variables:
+                            variables[variable_name] = ""
+                        variables[variable_name] += variable_value
+
+                        yield StreamChunkEvent(
+                            selector=[self._node_id, variable_name],
+                            chunk=variable_value,
+                            is_final=False,
+                        )
+                    else:
+                        variables[variable_name] = variable_value
+                case DatasourceMessage.MessageType.FILE:
+                    assert message.meta is not None
+                    files.append(message.meta["file"])
+                case (
+                    DatasourceMessage.MessageType.BLOB_CHUNK
+                    | DatasourceMessage.MessageType.LOG
+                    | DatasourceMessage.MessageType.RETRIEVER_RESOURCES
+                ):
+                    pass

         # mark the end of the stream
         yield StreamChunkEvent(
             selector=[self._node_id, "text"],
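Note: the message-dispatch hunk matches against dotted names such as DatasourceMessage.MessageType.BLOB. In pattern matching, a dotted name is a value pattern (compared with ==), while a bare name would be a capture pattern that matches anything and rebinds the name; this is why the enum members must be written out in full. The final grouped arm also makes the intentionally ignored message types (BLOB_CHUNK, LOG, RETRIEVER_RESOURCES) explicit. A sketch with a hypothetical enum standing in for DatasourceMessage.MessageType:

    from enum import Enum

    class MsgType(Enum):  # hypothetical stand-in
        TEXT = 1
        BLOB = 2
        LOG = 3

    def handle(t: MsgType) -> str:
        match t:
            case MsgType.TEXT | MsgType.BLOB:  # dotted names: value patterns
                return "payload"
            case MsgType.LOG:
                return "ignored"
        return "unmatched"
    # Pitfall: 'case TEXT:' (a bare name) would capture any value into
    # TEXT instead of comparing against MsgType.TEXT.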
api/services/dataset_service.py
@@ -2978,14 +2978,15 @@ class DocumentService:
         """
         now = naive_utc_now()

-        if action == "enable":
-            return DocumentService._prepare_enable_update(document, now)
-        elif action == "disable":
-            return DocumentService._prepare_disable_update(document, user, now)
-        elif action == "archive":
-            return DocumentService._prepare_archive_update(document, user, now)
-        elif action == "un_archive":
-            return DocumentService._prepare_unarchive_update(document, now)
+        match action:
+            case "enable":
+                return DocumentService._prepare_enable_update(document, now)
+            case "disable":
+                return DocumentService._prepare_disable_update(document, user, now)
+            case "archive":
+                return DocumentService._prepare_archive_update(document, user, now)
+            case "un_archive":
+                return DocumentService._prepare_unarchive_update(document, now)

         return None
@@ -3622,56 +3623,57 @@ class SegmentService:
         # Check if segment_ids is not empty to avoid WHERE false condition
         if not segment_ids or len(segment_ids) == 0:
            return
-        if action == "enable":
-            segments = db.session.scalars(
-                select(DocumentSegment).where(
-                    DocumentSegment.id.in_(segment_ids),
-                    DocumentSegment.dataset_id == dataset.id,
-                    DocumentSegment.document_id == document.id,
-                    DocumentSegment.enabled == False,
-                )
-            ).all()
-            if not segments:
-                return
-            real_deal_segment_ids = []
-            for segment in segments:
-                indexing_cache_key = f"segment_{segment.id}_indexing"
-                cache_result = redis_client.get(indexing_cache_key)
-                if cache_result is not None:
-                    continue
-                segment.enabled = True
-                segment.disabled_at = None
-                segment.disabled_by = None
-                db.session.add(segment)
-                real_deal_segment_ids.append(segment.id)
-            db.session.commit()
-
-            enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
-        elif action == "disable":
-            segments = db.session.scalars(
-                select(DocumentSegment).where(
-                    DocumentSegment.id.in_(segment_ids),
-                    DocumentSegment.dataset_id == dataset.id,
-                    DocumentSegment.document_id == document.id,
-                    DocumentSegment.enabled == True,
-                )
-            ).all()
-            if not segments:
-                return
-            real_deal_segment_ids = []
-            for segment in segments:
-                indexing_cache_key = f"segment_{segment.id}_indexing"
-                cache_result = redis_client.get(indexing_cache_key)
-                if cache_result is not None:
-                    continue
-                segment.enabled = False
-                segment.disabled_at = naive_utc_now()
-                segment.disabled_by = current_user.id
-                db.session.add(segment)
-                real_deal_segment_ids.append(segment.id)
-            db.session.commit()
-
-            disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
+        match action:
+            case "enable":
+                segments = db.session.scalars(
+                    select(DocumentSegment).where(
+                        DocumentSegment.id.in_(segment_ids),
+                        DocumentSegment.dataset_id == dataset.id,
+                        DocumentSegment.document_id == document.id,
+                        DocumentSegment.enabled == False,
+                    )
+                ).all()
+                if not segments:
+                    return
+                real_deal_segment_ids = []
+                for segment in segments:
+                    indexing_cache_key = f"segment_{segment.id}_indexing"
+                    cache_result = redis_client.get(indexing_cache_key)
+                    if cache_result is not None:
+                        continue
+                    segment.enabled = True
+                    segment.disabled_at = None
+                    segment.disabled_by = None
+                    db.session.add(segment)
+                    real_deal_segment_ids.append(segment.id)
+                db.session.commit()
+
+                enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
+            case "disable":
+                segments = db.session.scalars(
+                    select(DocumentSegment).where(
+                        DocumentSegment.id.in_(segment_ids),
+                        DocumentSegment.dataset_id == dataset.id,
+                        DocumentSegment.document_id == document.id,
+                        DocumentSegment.enabled == True,
+                    )
+                ).all()
+                if not segments:
+                    return
+                real_deal_segment_ids = []
+                for segment in segments:
+                    indexing_cache_key = f"segment_{segment.id}_indexing"
+                    cache_result = redis_client.get(indexing_cache_key)
+                    if cache_result is not None:
+                        continue
+                    segment.enabled = False
+                    segment.disabled_at = naive_utc_now()
+                    segment.disabled_by = current_user.id
+                    db.session.add(segment)
+                    real_deal_segment_ids.append(segment.id)
+                db.session.commit()
+
+                disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id)

     @classmethod
     def create_child_chunk(
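Note: in the DocumentService and SegmentService hunks every case body returns or commits on its own; when action matches no case, control reaches the code after the match (the trailing return None in DocumentService), exactly as the old elif chains behaved. A minimal sketch with placeholder return values:

    def prepare(action: str):
        match action:
            case "enable":
                return "enable-update"   # placeholder payloads
            case "disable":
                return "disable-update"
        return None  # reached only when no case matched

    assert prepare("unknown") is None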