mirror of https://github.com/langgenius/dify.git
add recommended rag plugin endpoint
This commit is contained in:
parent
826b9d5b21
commit
56fc9088dd
|
|
@ -107,7 +107,7 @@ class DatasourceFileManager:
|
|||
tenant_id: str,
|
||||
file_url: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> UploadFile:
|
||||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
|
|
@ -127,26 +127,22 @@ class DatasourceFileManager:
|
|||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
upload_file = UploadFile(
|
||||
tool_file = ToolFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=filepath,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=user_id,
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(blob).hexdigest(),
|
||||
source_url=file_url,
|
||||
created_at=datetime.now(),
|
||||
key=filepath,
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ from collections.abc import Generator
|
|||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional
|
||||
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,15 +31,15 @@ class DatasourceFileMessageTransformer:
|
|||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
file = DatasourceFileManager.create_file_by_url(
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
|
||||
url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
|
||||
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
|
|
@ -71,7 +71,8 @@ class DatasourceFileMessageTransformer:
|
|||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
file = DatasourceFileManager.create_file_by_raw(
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
|
|
@ -80,7 +81,7 @@ class DatasourceFileMessageTransformer:
|
|||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from core.workflow.nodes.tool.exc import ToolFileError
|
|||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
|
|
@ -376,19 +377,19 @@ class DatasourceNode(Node):
|
|||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.DATASOURCE_FILE
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"datasource_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type),
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
|
|
@ -404,14 +405,14 @@ class DatasourceNode(Node):
|
|||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"datasource_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.DATASOURCE_FILE,
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
|
|
@ -513,19 +514,19 @@ class DatasourceNode(Node):
|
|||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.DATASOURCE_FILE
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"datasource_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type),
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1162,7 +1162,7 @@ class RagPipelineService:
|
|||
)
|
||||
return node_exec
|
||||
|
||||
def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account | EndUser):
|
||||
def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
|
||||
"""
|
||||
Set datasource variables
|
||||
"""
|
||||
|
|
@ -1225,7 +1225,7 @@ class RagPipelineService:
|
|||
repository.save(workflow_node_execution)
|
||||
|
||||
# Convert node_execution to WorkflowNodeExecution after save
|
||||
workflow_node_execution_db_model = repository.to_db_model(workflow_node_execution)
|
||||
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_saver = DraftVariableSaver(
|
||||
|
|
@ -1235,6 +1235,7 @@ class RagPipelineService:
|
|||
node_type=NodeType(workflow_node_execution_db_model.node_type),
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
node_execution_id=workflow_node_execution.id,
|
||||
user=current_user,
|
||||
)
|
||||
draft_var_saver.save(
|
||||
process_data=workflow_node_execution.process_data,
|
||||
|
|
|
|||
Loading…
Reference in New Issue