diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index af50a58212..66d69499e7 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -107,7 +107,7 @@ class DatasourceFileManager: tenant_id: str, file_url: str, conversation_id: Optional[str] = None, - ) -> UploadFile: + ) -> ToolFile: # try to download image try: response = ssrf_proxy.get(file_url) @@ -127,26 +127,22 @@ class DatasourceFileManager: filepath = f"tools/{tenant_id}/{filename}" storage.save(filepath, blob) - upload_file = UploadFile( + tool_file = ToolFile( tenant_id=tenant_id, - storage_type=dify_config.STORAGE_TYPE, - key=filepath, + user_id=user_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + original_url=file_url, name=filename, size=len(blob), - extension=extension, - mime_type=mimetype, - created_by_role=CreatorUserRole.ACCOUNT, - created_by=user_id, - used=False, - hash=hashlib.sha3_256(blob).hexdigest(), - source_url=file_url, - created_at=datetime.now(), + key=filepath, ) - db.session.add(upload_file) + db.session.add(tool_file) db.session.commit() - return upload_file + return tool_file @staticmethod def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 6c93865264..d249e02064 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -3,9 +3,9 @@ from collections.abc import Generator from mimetypes import guess_extension, guess_type from typing import Optional -from core.datasource.datasource_file_manager import DatasourceFileManager from core.datasource.entities.datasource_entities import DatasourceMessage from core.file import File, FileTransferMethod, FileType +from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) @@ -31,15 +31,15 @@ class DatasourceFileMessageTransformer: # try to download image try: assert isinstance(message.message, DatasourceMessage.TextMessage) - - file = DatasourceFileManager.create_file_by_url( + tool_file_manager = ToolFileManager() + file = tool_file_manager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, conversation_id=conversation_id, ) - url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}" + url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}" yield DatasourceMessage( type=DatasourceMessage.MessageType.IMAGE_LINK, @@ -71,7 +71,8 @@ class DatasourceFileMessageTransformer: # FIXME: should do a type check here. assert isinstance(message.message.blob, bytes) - file = DatasourceFileManager.create_file_by_raw( + tool_file_manager = ToolFileManager() + file = tool_file_manager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, @@ -80,7 +81,7 @@ class DatasourceFileMessageTransformer: filename=filename, ) - url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type)) + url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype)) # check if file is image if "image" in mimetype: diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 824ff2b600..de2d03975a 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -30,6 +30,7 @@ from core.workflow.nodes.tool.exc import ToolFileError from extensions.ext_database import db from factories import file_factory from models.model import UploadFile +from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey @@ -376,19 +377,19 @@ class DatasourceNode(Node): assert isinstance(message.message, DatasourceMessage.TextMessage) url = message.message.text - transfer_method = FileTransferMethod.DATASOURCE_FILE + transfer_method = FileTransferMethod.TOOL_FILE datasource_file_id = str(url).split("/")[-1].split(".")[0] with Session(db.engine) as session: - stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) datasource_file = session.scalar(stmt) if datasource_file is None: raise ToolFileError(f"Tool file {datasource_file_id} does not exist") mapping = { - "datasource_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type), + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), "transfer_method": transfer_method, "url": url, } @@ -404,14 +405,14 @@ class DatasourceNode(Node): datasource_file_id = message.message.text.split("/")[-1].split(".")[0] with Session(db.engine) as session: - stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) datasource_file = session.scalar(stmt) if datasource_file is None: raise ToolFileError(f"datasource file {datasource_file_id} not exists") mapping = { - "datasource_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.DATASOURCE_FILE, + "tool_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, } files.append( @@ -513,19 +514,19 @@ class DatasourceNode(Node): assert isinstance(message.message, DatasourceMessage.TextMessage) url = message.message.text - transfer_method = FileTransferMethod.DATASOURCE_FILE + transfer_method = FileTransferMethod.TOOL_FILE datasource_file_id = str(url).split("/")[-1].split(".")[0] with Session(db.engine) as session: - stmt = select(UploadFile).where(UploadFile.id == datasource_file_id) + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) datasource_file = session.scalar(stmt) if datasource_file is None: raise ToolFileError(f"Tool file {datasource_file_id} does not exist") mapping = { - "datasource_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type), + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), "transfer_method": transfer_method, "url": url, } diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 9b0273e67e..795e2d5901 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -1162,7 +1162,7 @@ class RagPipelineService: ) return node_exec - def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser): + def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account): """ Set datasource variables """ @@ -1225,7 +1225,7 @@ class RagPipelineService: repository.save(workflow_node_execution) # Convert node_execution to WorkflowNodeExecution after save - workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution) + workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( @@ -1235,6 +1235,7 @@ class RagPipelineService: node_type=NodeType(workflow_node_execution_db_model.node_type), enclosing_node_id=enclosing_node_id, node_execution_id=workflow_node_execution.id, + user=current_user, ) draft_var_saver.save( process_data=workflow_node_execution.process_data,