mirror of https://github.com/langgenius/dify.git
add image file as markdown stream outupt
This commit is contained in:
parent
d8ab611480
commit
80f1fbba56
|
|
@ -27,7 +27,7 @@ class ToolFilePreviewApi(Resource):
|
|||
raise Forbidden('Invalid request.')
|
||||
|
||||
try:
|
||||
result = ToolFileManager.get_file_generator_by_message_file_id(
|
||||
result = ToolFileManager.get_file_generator_by_tool_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class MessageListApi(Resource):
|
|||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'answer': fields.String(attribute='re_sign_file_url_answer'),
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class MessageListApi(WebApiResource):
|
|||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'answer': fields.String(attribute='re_sign_file_url_answer'),
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
|
|
|
|||
|
|
@ -183,7 +183,7 @@ class TextToSpeechEntity(BaseModel):
|
|||
language: Optional[str] = None
|
||||
|
||||
|
||||
class FileUploadEntity(BaseModel):
|
||||
class FileExtraConfig(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
"""
|
||||
|
|
@ -191,7 +191,7 @@ class FileUploadEntity(BaseModel):
|
|||
|
||||
|
||||
class AppAdditionalFeatures(BaseModel):
|
||||
file_upload: Optional[FileUploadEntity] = None
|
||||
file_upload: Optional[FileExtraConfig] = None
|
||||
opening_statement: Optional[str] = None
|
||||
suggested_questions: list[str] = []
|
||||
suggested_questions_after_answer: bool = False
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import FileUploadEntity
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> Optional[FileUploadEntity]:
|
||||
def convert(cls, config: dict) -> Optional[FileExtraConfig]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ class FileUploadConfigManager:
|
|||
if file_upload_dict:
|
||||
if 'image' in file_upload_dict and file_upload_dict['image']:
|
||||
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
|
||||
return FileUploadEntity(
|
||||
return FileExtraConfig(
|
||||
image_config={
|
||||
'number_limits': file_upload_dict['image']['number_limits'],
|
||||
'detail': file_upload_dict['image']['detail'],
|
||||
|
|
|
|||
|
|
@ -67,11 +67,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
# parse files
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict)
|
||||
if file_upload_entity:
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict)
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_upload_entity,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
|
@ -11,7 +12,6 @@ from core.app.entities.queue_entities import (
|
|||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
|
|
@ -34,6 +34,7 @@ from core.app.entities.task_entities import (
|
|||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
|
|
@ -260,10 +261,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
annotation = self._handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.answer = annotation.content
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
# elif isinstance(event, QueueMessageFileEvent):
|
||||
# response = self._message_file_to_stream_response(event)
|
||||
# if response:
|
||||
# yield response
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
|
|
@ -464,10 +465,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
text = None
|
||||
if isinstance(value, str | int | float):
|
||||
text = str(value)
|
||||
elif isinstance(value, object): # TODO FILE
|
||||
# convert file to markdown
|
||||
text = f'})'
|
||||
pass
|
||||
elif isinstance(value, dict | list):
|
||||
# handle files
|
||||
file_vars = self._fetch_files_from_variable_value(value)
|
||||
for file_var in file_vars:
|
||||
try:
|
||||
file_var_obj = FileVar(**file_var)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating file var: {e}')
|
||||
continue
|
||||
|
||||
# convert file to markdown
|
||||
text = file_var_obj.to_markdown()
|
||||
|
||||
if not text:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
if text:
|
||||
for token in text:
|
||||
|
|
|
|||
|
|
@ -81,11 +81,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
# parse files
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_upload_entity:
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_upload_entity,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
|
|||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.file_obj import FileObj
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
||||
|
|
@ -33,7 +33,7 @@ class AppRunner:
|
|||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
query: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get pre calculate rest tokens
|
||||
|
|
@ -125,7 +125,7 @@ class AppRunner:
|
|||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
|
|
|
|||
|
|
@ -81,11 +81,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
# parse files
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_upload_entity:
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_upload_entity,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -76,11 +76,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
# parse files
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_upload_entity:
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_upload_entity,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
else:
|
||||
|
|
@ -233,11 +233,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_upload_entity:
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
message.files,
|
||||
file_upload_entity,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||
transfer_method=file.transfer_method.value,
|
||||
belongs_to='user',
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=('account' if account_id else 'end_user'),
|
||||
created_by=account_id or end_user_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -50,11 +50,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
# parse files
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict)
|
||||
if file_upload_entity:
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict)
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_upload_entity,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||
|
||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file.file_obj import FileObj
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ class AppGenerateEntity(BaseModel):
|
|||
app_config: AppConfig
|
||||
|
||||
inputs: dict[str, str]
|
||||
files: list[FileObj] = []
|
||||
files: list[FileVar] = []
|
||||
user_id: str
|
||||
|
||||
# extras
|
||||
|
|
|
|||
|
|
@ -204,6 +204,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||
total_steps: int
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
|
||||
workflow_run_id: str
|
||||
|
|
@ -253,6 +254,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
execution_metadata: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_FINISHED
|
||||
workflow_run_id: str
|
||||
|
|
|
|||
|
|
@ -97,6 +97,11 @@ class MessageCycleManage:
|
|||
)
|
||||
|
||||
if message_file:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
|
||||
# get extension
|
||||
if '.' in message_file.url:
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
|
|
@ -105,7 +110,7 @@ class MessageCycleManage:
|
|||
else:
|
||||
extension = '.bin'
|
||||
# add sign url
|
||||
url = ToolFileManager.sign_file(file_id=message_file.id, extension=extension)
|
||||
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
|
||||
|
||||
return MessageFileStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from core.app.entities.task_entities import (
|
|||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -93,7 +94,7 @@ class WorkflowCycleManage:
|
|||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[dict] = None) -> WorkflowRun:
|
||||
outputs: Optional[str] = None) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
|
|
@ -244,7 +245,8 @@ class WorkflowCycleManage:
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_start_to_stream_response(self, task_id: str, workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
|
||||
def _workflow_start_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
|
||||
"""
|
||||
Workflow start to stream response.
|
||||
:param task_id: task id
|
||||
|
|
@ -262,7 +264,8 @@ class WorkflowCycleManage:
|
|||
)
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(self, task_id: str, workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
|
||||
def _workflow_finish_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
|
||||
"""
|
||||
Workflow finish to stream response.
|
||||
:param task_id: task id
|
||||
|
|
@ -283,7 +286,8 @@ class WorkflowCycleManage:
|
|||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp())
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -310,7 +314,7 @@ class WorkflowCycleManage:
|
|||
)
|
||||
|
||||
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeFinishStreamResponse:
|
||||
-> NodeFinishStreamResponse:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
:param task_id: task id
|
||||
|
|
@ -334,7 +338,8 @@ class WorkflowCycleManage:
|
|||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp())
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -465,3 +470,48 @@ class WorkflowCycleManage:
|
|||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
:return:
|
||||
"""
|
||||
files = []
|
||||
for output_var, output_value in outputs_dict.items():
|
||||
file_vars = self._fetch_files_from_variable_value(output_value)
|
||||
if file_vars:
|
||||
files.extend(file_vars)
|
||||
|
||||
return files
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file_var = self._get_file_var_from_value(item)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
elif isinstance(value, dict):
|
||||
file_var = self._get_file_var_from_value(value)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
|
||||
return files
|
||||
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
if '__variant' in value and value['__variant'] == FileVar.__class__.__name__:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from typing import Optional
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.app_config.entities import FileUploadEntity
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -44,27 +45,65 @@ class FileBelongsTo(enum.Enum):
|
|||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
class FileObj(BaseModel):
|
||||
id: Optional[str]
|
||||
|
||||
class FileVar(BaseModel):
|
||||
id: Optional[str] = None # message file id
|
||||
tenant_id: str
|
||||
type: FileType
|
||||
transfer_method: FileTransferMethod
|
||||
url: Optional[str]
|
||||
upload_file_id: Optional[str]
|
||||
file_upload_entity: FileUploadEntity
|
||||
url: Optional[str] = None # remote url
|
||||
related_id: Optional[str] = None
|
||||
extra_config: Optional[FileExtraConfig] = None
|
||||
filename: Optional[str] = None
|
||||
extension: Optional[str] = None
|
||||
mime_type: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'__variant': self.__class__.__name__,
|
||||
'type': self.type.value,
|
||||
'transfer_method': self.transfer_method.value,
|
||||
'url': self.preview_url,
|
||||
'related_id': self.related_id,
|
||||
'filename': self.filename,
|
||||
'extension': self.extension,
|
||||
'mime_type': self.mime_type,
|
||||
}
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
"""
|
||||
Convert file to markdown
|
||||
:return:
|
||||
"""
|
||||
preview_url = self.preview_url
|
||||
if self.type == FileType.IMAGE:
|
||||
text = f''
|
||||
else:
|
||||
text = f'[{self.filename or self.preview_url}]({self.preview_url})'
|
||||
|
||||
return text
|
||||
|
||||
@property
|
||||
def data(self) -> Optional[str]:
|
||||
"""
|
||||
Get image data, file signed url or base64 data
|
||||
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
|
||||
:return:
|
||||
"""
|
||||
return self._get_data()
|
||||
|
||||
@property
|
||||
def preview_url(self) -> Optional[str]:
|
||||
"""
|
||||
Get signed preview url
|
||||
:return:
|
||||
"""
|
||||
return self._get_data(force_url=True)
|
||||
|
||||
@property
|
||||
def prompt_message_content(self) -> ImagePromptMessageContent:
|
||||
if self.type == FileType.IMAGE:
|
||||
image_config = self.file_upload_entity.image_config
|
||||
image_config = self.extra_config.image_config
|
||||
|
||||
return ImagePromptMessageContent(
|
||||
data=self.data,
|
||||
|
|
@ -79,7 +118,7 @@ class FileObj(BaseModel):
|
|||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = (db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == self.upload_file_id,
|
||||
UploadFile.id == self.related_id,
|
||||
UploadFile.tenant_id == self.tenant_id
|
||||
).first())
|
||||
|
||||
|
|
@ -87,5 +126,15 @@ class FileObj(BaseModel):
|
|||
upload_file=upload_file,
|
||||
force_url=force_url
|
||||
)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
# get extension
|
||||
if '.' in self.url:
|
||||
extension = f'.{self.url.split(".")[-1]}'
|
||||
if len(extension) > 10:
|
||||
extension = '.bin'
|
||||
else:
|
||||
extension = '.bin'
|
||||
# add sign url
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from typing import Union
|
|||
|
||||
import requests
|
||||
|
||||
from core.app.app_config.entities import FileUploadEntity
|
||||
from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.file.file_obj import FileBelongsTo, FileTransferMethod, FileType, FileVar
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, MessageFile, UploadFile
|
||||
|
|
@ -16,13 +16,13 @@ class MessageFileParser:
|
|||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
|
||||
def validate_and_transform_files_arg(self, files: list[dict], file_upload_entity: FileUploadEntity,
|
||||
user: Union[Account, EndUser]) -> list[FileObj]:
|
||||
def validate_and_transform_files_arg(self, files: list[dict], file_extra_config: FileExtraConfig,
|
||||
user: Union[Account, EndUser]) -> list[FileVar]:
|
||||
"""
|
||||
validate and transform files arg
|
||||
|
||||
:param files:
|
||||
:param file_upload_entity:
|
||||
:param file_extra_config:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
|
|
@ -44,14 +44,14 @@ class MessageFileParser:
|
|||
raise ValueError('Missing file upload_file_id')
|
||||
|
||||
# transform files to file objs
|
||||
type_file_objs = self._to_file_objs(files, file_upload_entity)
|
||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
||||
|
||||
# validate files
|
||||
new_files = []
|
||||
for file_type, file_objs in type_file_objs.items():
|
||||
if file_type == FileType.IMAGE:
|
||||
# parse and validate files
|
||||
image_config = file_upload_entity.image_config
|
||||
image_config = file_extra_config.image_config
|
||||
|
||||
# check if image file feature is enabled
|
||||
if not image_config:
|
||||
|
|
@ -79,7 +79,7 @@ class MessageFileParser:
|
|||
# get upload file from upload_file_id
|
||||
upload_file = (db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == file_obj.upload_file_id,
|
||||
UploadFile.id == file_obj.related_id,
|
||||
UploadFile.tenant_id == self.tenant_id,
|
||||
UploadFile.created_by == user.id,
|
||||
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
|
|
@ -95,30 +95,30 @@ class MessageFileParser:
|
|||
# return all file objs
|
||||
return new_files
|
||||
|
||||
def transform_message_files(self, files: list[MessageFile], file_upload_entity: FileUploadEntity) -> list[FileObj]:
|
||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
|
||||
"""
|
||||
transform message files
|
||||
|
||||
:param files:
|
||||
:param file_upload_entity:
|
||||
:param file_extra_config:
|
||||
:return:
|
||||
"""
|
||||
# transform files to file objs
|
||||
type_file_objs = self._to_file_objs(files, file_upload_entity)
|
||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
||||
|
||||
# return all file objs
|
||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
||||
|
||||
def _to_file_objs(self, files: list[Union[dict, MessageFile]],
|
||||
file_upload_entity: FileUploadEntity) -> dict[FileType, list[FileObj]]:
|
||||
file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]:
|
||||
"""
|
||||
transform files to file objs
|
||||
|
||||
:param files:
|
||||
:param file_upload_entity:
|
||||
:param file_extra_config:
|
||||
:return:
|
||||
"""
|
||||
type_file_objs: dict[FileType, list[FileObj]] = {
|
||||
type_file_objs: dict[FileType, list[FileVar]] = {
|
||||
# Currently only support image
|
||||
FileType.IMAGE: []
|
||||
}
|
||||
|
|
@ -132,7 +132,7 @@ class MessageFileParser:
|
|||
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
|
||||
continue
|
||||
|
||||
file_obj = self._to_file_obj(file, file_upload_entity)
|
||||
file_obj = self._to_file_obj(file, file_extra_config)
|
||||
if file_obj.type not in type_file_objs:
|
||||
continue
|
||||
|
||||
|
|
@ -140,7 +140,7 @@ class MessageFileParser:
|
|||
|
||||
return type_file_objs
|
||||
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_entity: FileUploadEntity) -> FileObj:
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
|
||||
"""
|
||||
transform file to file obj
|
||||
|
||||
|
|
@ -149,23 +149,23 @@ class MessageFileParser:
|
|||
"""
|
||||
if isinstance(file, dict):
|
||||
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
|
||||
return FileObj(
|
||||
return FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.get('type')),
|
||||
transfer_method=transfer_method,
|
||||
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
file_upload_entity=file_upload_entity
|
||||
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
extra_config=file_extra_config
|
||||
)
|
||||
else:
|
||||
return FileObj(
|
||||
return FileVar(
|
||||
id=file.id,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.type),
|
||||
transfer_method=FileTransferMethod.value_of(file.transfer_method),
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id or None,
|
||||
file_upload_entity=file_upload_entity
|
||||
related_id=file.upload_file_id or None,
|
||||
extra_config=file_extra_config
|
||||
)
|
||||
|
||||
def _check_image_remote_url(self, url):
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from extensions.ext_storage import storage
|
|||
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
|
|
@ -23,7 +24,7 @@ class UploadFileParser:
|
|||
return None
|
||||
|
||||
if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file)
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
|
|
@ -36,7 +37,7 @@ class UploadFileParser:
|
|||
return f'data:{upload_file.mime_type};base64,{encoded_string}'
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file) -> str:
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
|
|
@ -44,11 +45,11 @@ class UploadFileParser:
|
|||
:return:
|
||||
"""
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'
|
||||
image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview'
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
|
|
|||
|
|
@ -45,14 +45,14 @@ class TokenBufferMemory:
|
|||
files = message.message_files
|
||||
if files:
|
||||
if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
file_upload_entity = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
else:
|
||||
file_upload_entity = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict)
|
||||
file_extra_config = FileUploadConfigManager.convert(message.workflow_run.workflow.features_dict)
|
||||
|
||||
if file_upload_entity:
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
files,
|
||||
file_upload_entity
|
||||
file_extra_config
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file.file_obj import FileObj
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
|
|
@ -25,7 +25,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
|
||||
inputs: dict,
|
||||
query: str,
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
|
|
@ -62,7 +62,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
prompt_template: CompletionModelPromptTemplate,
|
||||
inputs: dict,
|
||||
query: Optional[str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
|
|
@ -113,7 +113,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
prompt_template: list[ChatModelMessage],
|
||||
inputs: dict,
|
||||
query: Optional[str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Optional
|
|||
|
||||
from core.app.app_config.entities import PromptTemplateEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file.file_obj import FileObj
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
|
|
@ -50,7 +50,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> \
|
||||
|
|
@ -161,7 +161,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
|
|
@ -204,7 +204,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
|
|
@ -253,7 +253,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
|
||||
return [self.get_last_user_message(prompt, files)], stops
|
||||
|
||||
def get_last_user_message(self, prompt: str, files: list[FileObj]) -> UserPromptMessage:
|
||||
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
|
||||
if files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||
for file in files:
|
||||
|
|
|
|||
|
|
@ -21,16 +21,16 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class ToolFileManager:
|
||||
@staticmethod
|
||||
def sign_file(file_id: str, extension: str) -> str:
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
file_preview_url = f'{base_url}/files/tools/{file_id}{extension}'
|
||||
file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}'
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
|
@ -163,23 +163,14 @@ class ToolFileManager:
|
|||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_message_file_id(id: str) -> Union[tuple[Generator, str], None]:
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
:param tool_file_id: the id of the tool file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile = db.session.query(MessageFile).filter(
|
||||
MessageFile.id == id,
|
||||
).first()
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
|
||||
tool_file: ToolFile = db.session.query(ToolFile).filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
).first()
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list]
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
|
||||
class ValueType(Enum):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
|||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file.file_obj import FileObj
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
|
@ -51,15 +51,10 @@ class LLMNode(BaseNode):
|
|||
}
|
||||
|
||||
# fetch files
|
||||
files: list[FileObj] = self._fetch_files(node_data, variable_pool)
|
||||
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
|
||||
|
||||
if files:
|
||||
node_inputs['#files#'] = [{
|
||||
'type': file.type.value,
|
||||
'transfer_method': file.transfer_method.value,
|
||||
'url': file.url,
|
||||
'upload_file_id': file.upload_file_id,
|
||||
} for file in files]
|
||||
node_inputs['#files#'] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
context = self._fetch_context(node_data, variable_pool)
|
||||
|
|
@ -202,7 +197,7 @@ class LLMNode(BaseNode):
|
|||
|
||||
return inputs
|
||||
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]:
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
||||
"""
|
||||
Fetch files
|
||||
:param node_data: node data
|
||||
|
|
@ -350,7 +345,7 @@ class LLMNode(BaseNode):
|
|||
|
||||
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from os import path
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
|
|
@ -58,19 +58,19 @@ class ToolNode(BaseNode):
|
|||
},
|
||||
inputs=parameters
|
||||
)
|
||||
|
||||
|
||||
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
|
||||
"""
|
||||
Generate parameters
|
||||
"""
|
||||
return {
|
||||
k.variable:
|
||||
k.value if k.variable_type == 'static' else
|
||||
k.variable:
|
||||
k.value if k.variable_type == 'static' else
|
||||
variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else ''
|
||||
for k in node_data.tool_parameters
|
||||
}
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
|
|
@ -87,7 +87,7 @@ class ToolNode(BaseNode):
|
|||
|
||||
return plain_text, files
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
|
|
@ -95,46 +95,50 @@ class ToolNode(BaseNode):
|
|||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
url = response.message
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get('mime_type', 'image/jpeg')
|
||||
filename = response.save_as or url.split('/')[-1]
|
||||
result.append({
|
||||
'type': 'image',
|
||||
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||
'url': url,
|
||||
'upload_file_id': None,
|
||||
'filename': filename,
|
||||
'file-ext': ext,
|
||||
'mime-type': mimetype,
|
||||
})
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = url.split('/')[-1]
|
||||
result.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append({
|
||||
'type': 'image', # TODO: only support image for now
|
||||
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||
'url': response.message,
|
||||
'upload_file_id': None,
|
||||
'filename': response.save_as,
|
||||
'file-ext': path.splitext(response.save_as)[1],
|
||||
'mime-type': response.meta.get('mime_type', 'application/octet-stream'),
|
||||
})
|
||||
# get tool file id
|
||||
tool_file_id = response.message.split('/')[-1]
|
||||
result.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get('mime_type', 'application/octet-stream'),
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
pass # TODO:
|
||||
pass # TODO:
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Extract tool response text
|
||||
"""
|
||||
return ''.join([
|
||||
f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||
f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||
for message in tool_response
|
||||
])
|
||||
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ message_detail_fields = {
|
|||
'query': fields.String,
|
||||
'message': fields.Raw,
|
||||
'message_tokens': fields.Integer,
|
||||
'answer': fields.String,
|
||||
'answer': fields.String(attribute='re_sign_file_url_answer'),
|
||||
'answer_tokens': fields.Integer,
|
||||
'provider_response_latency': fields.Float,
|
||||
'from_source': fields.String,
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ message_fields = {
|
|||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'answer': fields.String(attribute='re_sign_file_url_answer'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import re
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
|
@ -610,6 +611,71 @@ class Message(db.Model):
|
|||
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
|
||||
workflow_run_id = db.Column(UUID)
|
||||
|
||||
@property
|
||||
def re_sign_file_url_answer(self) -> str:
|
||||
if not self.answer:
|
||||
return self.answer
|
||||
|
||||
pattern = r'\[!?.*?\]\((((http|https):\/\/[\w.-]+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)'
|
||||
matches = re.findall(pattern, self.answer)
|
||||
|
||||
if not matches:
|
||||
return self.answer
|
||||
|
||||
urls = [match[0] for match in matches]
|
||||
|
||||
# remove duplicate urls
|
||||
urls = list(set(urls))
|
||||
|
||||
if not urls:
|
||||
return self.answer
|
||||
|
||||
re_sign_file_url_answer = self.answer
|
||||
for url in urls:
|
||||
if 'files/tools' in url:
|
||||
# get tool file id
|
||||
tool_file_id_pattern = r'\/files\/tools\/([\.\w-]+)?\?timestamp='
|
||||
result = re.search(tool_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
|
||||
tool_file_id = result.group(1)
|
||||
|
||||
# get extension
|
||||
if '.' in tool_file_id:
|
||||
split_result = tool_file_id.split('.')
|
||||
extension = f'.{split_result[-1]}'
|
||||
if len(extension) > 10:
|
||||
extension = '.bin'
|
||||
tool_file_id = split_result[0]
|
||||
else:
|
||||
extension = '.bin'
|
||||
|
||||
if not tool_file_id:
|
||||
continue
|
||||
|
||||
sign_url = ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=tool_file_id,
|
||||
extension=extension
|
||||
)
|
||||
else:
|
||||
# get upload file id
|
||||
upload_file_id_pattern = r'\/files\/([\w-]+)\/image-preview?\?timestamp='
|
||||
result = re.search(upload_file_id_pattern, url)
|
||||
if not result:
|
||||
continue
|
||||
|
||||
upload_file_id = result.group(1)
|
||||
|
||||
if not upload_file_id:
|
||||
continue
|
||||
|
||||
sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id)
|
||||
|
||||
re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url)
|
||||
|
||||
return re_sign_file_url_answer
|
||||
|
||||
@property
|
||||
def user_feedback(self):
|
||||
feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id,
|
||||
|
|
@ -680,7 +746,7 @@ class Message(db.Model):
|
|||
if message_file.transfer_method == 'local_file':
|
||||
upload_file = (db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == message_file.upload_file_id
|
||||
UploadFile.id == message_file.related_id
|
||||
).first())
|
||||
|
||||
url = UploadFileParser.get_image_data(
|
||||
|
|
@ -688,6 +754,11 @@ class Message(db.Model):
|
|||
force_url=True
|
||||
)
|
||||
if message_file.transfer_method == 'tool_file':
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
|
||||
# get extension
|
||||
if '.' in message_file.url:
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
|
|
@ -696,7 +767,7 @@ class Message(db.Model):
|
|||
else:
|
||||
extension = '.bin'
|
||||
# add sign url
|
||||
url = ToolFileParser.get_tool_file_manager().sign_file(file_id=message_file.id, extension=extension)
|
||||
url = ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=tool_file_id, extension=extension)
|
||||
|
||||
files.append({
|
||||
'id': message_file.id,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from core.app.app_config.entities import (
|
|||
DatasetRetrieveConfigEntity,
|
||||
EasyUIBasedAppConfig,
|
||||
ExternalDataVariableEntity,
|
||||
FileUploadEntity,
|
||||
FileExtraConfig,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
|
|
@ -416,7 +416,7 @@ class WorkflowConverter:
|
|||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileUploadEntity] = None) -> dict:
|
||||
file_upload: Optional[FileExtraConfig] = None) -> dict:
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param new_app_mode: new app mode
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity
|
||||
from core.file.file_obj import FileObj, FileType, FileTransferMethod
|
||||
from core.app.app_config.entities import ModelConfigEntity, FileExtraConfig
|
||||
from core.file.file_obj import FileVar, FileType, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
|
|
@ -138,13 +138,13 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
|||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = [
|
||||
FileObj(
|
||||
FileVar(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
url="https://example.com/image1.jpg",
|
||||
file_upload_entity=FileUploadEntity(
|
||||
extra_config=FileExtraConfig(
|
||||
image_config={
|
||||
"detail": "high",
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue