diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 690297048e..405d5ed607 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -66,10 +66,17 @@ class MessageFeedbackApi(InstalledAppResource): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args.get("rating"), args.get("content")) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=current_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index bed89a99a5..773ea0e0c6 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -108,7 +108,13 @@ class MessageFeedbackApi(Resource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content")) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=end_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 84c58c62df..ea664b8f1b 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -8,12 +8,16 @@ from werkzeug.exceptions import NotFound import services.dataset_service from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api -from controllers.service_api.app.error import ProviderNotInitializeError +from controllers.service_api.app.error import ( + FileTooLargeError, + NoFileUploadedError, + ProviderNotInitializeError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, DocumentIndexingError, - NoFileUploadedError, - TooManyFilesError, ) from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.errors.error import ProviderTokenNotInitError @@ -238,13 +242,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError - upload_file = FileService.upload_file( - filename=file.filename, - content=file.read(), - mimetype=file.mimetype, - user=current_user, - source="datasets", - ) + try: + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 
1f1b2b568e..ed936643dd 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -369,11 +369,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc with Session(db.engine) as session: workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event) - response_finish = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + response_finish = self._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if response_finish: yield response_finish diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2258747a2c..df48a83316 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -114,9 +114,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa } self._task_state = WorkflowTaskState() - self._wip_workflow_node_executions = {} - self._wip_workflow_agent_logs = {} - self.total_tokens: int = 0 + self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 8840737245..1a2e67f7e7 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -67,8 +67,6 @@ class WorkflowCycleManage: _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] _task_state: WorkflowTaskState _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - _wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]] def _handle_workflow_run_start( self, @@ -313,33 +311,11 @@ class WorkflowCycleManage: inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) - execution_metadata_dict = event.execution_metadata - if self._wip_workflow_agent_logs.get(workflow_node_execution.id): - if not execution_metadata_dict: - execution_metadata_dict = {} - - execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get( - workflow_node_execution.id, [] - ) - + execution_metadata_dict = dict(event.execution_metadata or {}) execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.execution_metadata: execution_metadata, - 
WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - } - ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value @@ -372,35 +348,9 @@ class WorkflowCycleManage: outputs = WorkflowEntry.handle_special_values(event.outputs) finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - execution_metadata_dict = event.execution_metadata - if self._wip_workflow_agent_logs.get(workflow_node_execution.id): - if not execution_metadata_dict: - execution_metadata_dict = {} - - execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get( - workflow_node_execution.id, [] - ) - - execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: ( - WorkflowNodeExecutionStatus.FAILED.value - if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value - ), - WorkflowNodeExecution.error: event.error, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - WorkflowNodeExecution.execution_metadata: execution_metadata, - } + execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = ( WorkflowNodeExecutionStatus.FAILED.value @@ -889,41 +839,10 @@ class WorkflowCycleManage: :param event: agent log event :return: """ - node_execution = self._wip_workflow_node_executions.get(event.node_execution_id) - if not node_execution: - raise Exception(f"Workflow node execution not found: {event.node_execution_id}") - - node_execution_id = node_execution.id - original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, []) - - # try to find the log with the same id - for log in original_agent_logs: - if log.id == event.id: - # update the log - log.status = event.status - log.error = event.error - log.data = event.data - break - else: - # append the log - original_agent_logs.append( - AgentLogStreamResponse.Data( - id=event.id, - parent_id=event.parent_id, - node_execution_id=node_execution_id, - error=event.error, - status=event.status, - data=event.data, - label=event.label, - ) - ) - - self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs - return AgentLogStreamResponse( task_id=task_id, data=AgentLogStreamResponse.Data( - node_execution_id=node_execution_id, + node_execution_id=event.node_execution_id, id=event.id, parent_id=event.parent_id, label=event.label, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 937ce1cc5e..da79b1bf03 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -31,7 +31,6 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.rag_web_reader import 
get_image_upload_file_ids -from core.tools.utils.text_processing_utils import remove_leading_symbols from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py deleted file mode 100644 index 03818741f6..0000000000 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ /dev/null @@ -1,770 +0,0 @@ -import copy -import json -import logging -from collections.abc import Generator, Sequence -from typing import Optional, Union, cast - -import tiktoken -from openai import AzureOpenAI, Stream -from openai.types import Completion -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall - -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageFunction, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS -from core.model_runtime.utils import helper - -logger = logging.getLogger(__name__) - - -class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): - def _invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - base_model_name = self._get_base_model_name(credentials) - ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - - if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: - # chat model - return self._chat_generate( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ) - else: - # text completion model - return self._generate( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stop=stop, - stream=stream, - user=user, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ) -> int: - base_model_name = self._get_base_model_name(credentials) - model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - if not model_entity: - raise ValueError(f"Base Model Name {base_model_name} is invalid") - model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) - - if model_mode == LLMMode.CHAT.value: - # chat model - return self._num_tokens_from_messages(credentials, 
prompt_messages, tools) - else: - # text completion model, do not support tool calling - content = prompt_messages[0].content - assert isinstance(content, str) - return self._num_tokens_from_string(credentials, content) - - def validate_credentials(self, model: str, credentials: dict) -> None: - if "openai_api_base" not in credentials: - raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - - if "openai_api_key" not in credentials: - raise CredentialsValidateFailedError("Azure OpenAI API key is required") - - if "base_model_name" not in credentials: - raise CredentialsValidateFailedError("Base Model Name is required") - - base_model_name = self._get_base_model_name(credentials) - ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - - if not ai_model_entity: - raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') - - try: - client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - - if model.startswith("o1"): - client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=1, - max_completion_tokens=20, - stream=False, - ) - elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: - # chat model - client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=0, - max_tokens=20, - stream=False, - ) - else: - # text completion model - client.completions.create( - prompt="ping", - model=model, - temperature=0, - max_tokens=20, - stream=False, - ) - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = self._get_base_model_name(credentials) - ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) - return ai_model_entity.entity if ai_model_entity else None - - def _generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - - extra_model_kwargs = {} - - if stop: - extra_model_kwargs["stop"] = stop - - if user: - extra_model_kwargs["user"] = user - - # text completion model - response = client.completions.create( - prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs - ) - - if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) - - return self._handle_generate_response(model, credentials, response, prompt_messages) - - def _handle_generate_response( - self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] - ): - assistant_text = response.choices[0].text - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=assistant_text) - - # calculate num tokens - if response.usage: - # transform usage - prompt_tokens = response.usage.prompt_tokens - completion_tokens = response.usage.completion_tokens - else: - # calculate num tokens - content = prompt_messages[0].content - assert isinstance(content, str) - prompt_tokens = self._num_tokens_from_string(credentials, content) - completion_tokens = self._num_tokens_from_string(credentials, 
assistant_text) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - # transform response - result = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage, - system_fingerprint=response.system_fingerprint, - ) - - return result - - def _handle_generate_stream_response( - self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] - ) -> Generator: - full_text = "" - for chunk in response: - if len(chunk.choices) == 0: - continue - - delta = chunk.choices[0] - - if delta.finish_reason is None and (delta.text is None or delta.text == ""): - continue - - # transform assistant message to prompt message - text = delta.text or "" - assistant_prompt_message = AssistantPromptMessage(content=text) - - full_text += text - - if delta.finish_reason is not None: - # calculate num tokens - if chunk.usage: - # transform usage - prompt_tokens = chunk.usage.prompt_tokens - completion_tokens = chunk.usage.completion_tokens - else: - # calculate num tokens - content = prompt_messages[0].content - assert isinstance(content, str) - prompt_tokens = self._num_tokens_from_string(credentials, content) - completion_tokens = self._num_tokens_from_string(credentials, full_text) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( - model=chunk.model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage, - ), - ) - else: - yield LLMResultChunk( - model=chunk.model, - prompt_messages=prompt_messages, - system_fingerprint=chunk.system_fingerprint, - delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - ), - ) - - def _chat_generate( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, - stream: bool = True, - user: Optional[str] = None, - ) -> Union[LLMResult, Generator]: - client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - - response_format = model_parameters.get("response_format") - if response_format: - if response_format == "json_schema": - json_schema = model_parameters.get("json_schema") - if not json_schema: - raise ValueError("Must define JSON Schema when the response format is json_schema") - try: - schema = json.loads(json_schema) - except: - raise ValueError(f"not correct json_schema format: {json_schema}") - model_parameters.pop("json_schema") - model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} - else: - model_parameters["response_format"] = {"type": response_format} - - extra_model_kwargs = {} - - if tools: - extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - - if stop: - extra_model_kwargs["stop"] = stop - - if user: - extra_model_kwargs["user"] = user - - # clear illegal prompt messages - prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) - - block_as_stream = False - if model.startswith("o1"): - if "max_tokens" in model_parameters: - model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] - del model_parameters["max_tokens"] - if stream: - block_as_stream = True - 
stream = False - - if "stream_options" in extra_model_kwargs: - del extra_model_kwargs["stream_options"] - - if "stop" in extra_model_kwargs: - del extra_model_kwargs["stop"] - - # chat model - response = client.chat.completions.create( - messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs, - ) - - if stream: - return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) - - block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - if block_as_stream: - return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) - - return block_result - - def _handle_chat_block_as_stream_response( - self, - block_result: LLMResult, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Handle llm chat response - - :param model: model name - :param credentials: credentials - :param response: response - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :param stop: stop words - :return: llm response chunk generator - """ - text = block_result.message.content - text = cast(str, text) - - if stop: - text = self.enforce_stop_tokens(text, stop) - - yield LLMResultChunk( - model=block_result.model, - prompt_messages=prompt_messages, - system_fingerprint=block_result.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=text), - finish_reason="stop", - usage=block_result.usage, - ), - ) - - def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: - """ - Clear illegal prompt messages for OpenAI API - - :param model: model name - :param prompt_messages: prompt messages - :return: cleaned prompt messages - """ - checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] - - if model in checklist: - # count how many user messages are there - user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)]) - if user_message_count > 1: - for prompt_message in prompt_messages: - if isinstance(prompt_message, UserPromptMessage): - if isinstance(prompt_message.content, list): - prompt_message.content = "\n".join( - [ - item.data - if item.type == PromptMessageContentType.TEXT - else "[IMAGE]" - if item.type == PromptMessageContentType.IMAGE - else "" - for item in prompt_message.content - ] - ) - - if model.startswith("o1"): - system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) - if system_message_count > 0: - new_prompt_messages = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message, SystemPromptMessage): - prompt_message = UserPromptMessage( - content=prompt_message.content, - name=prompt_message.name, - ) - - new_prompt_messages.append(prompt_message) - prompt_messages = new_prompt_messages - - return prompt_messages - - def _handle_chat_generate_response( - self, - model: str, - credentials: dict, - response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ): - assistant_message = response.choices[0].message - assistant_message_tool_calls = assistant_message.tool_calls - - # extract tool calls from response - tool_calls = [] - self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) - - # transform assistant 
message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) - - # calculate num tokens - if response.usage: - # transform usage - prompt_tokens = response.usage.prompt_tokens - completion_tokens = response.usage.completion_tokens - else: - # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - # transform response - result = LLMResult( - model=response.model or model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage, - system_fingerprint=response.system_fingerprint, - ) - - return result - - def _handle_chat_generate_stream_response( - self, - model: str, - credentials: dict, - response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, - ): - index = 0 - full_assistant_content = "" - real_model = model - system_fingerprint = None - completion = "" - tool_calls = [] - for chunk in response: - if len(chunk.choices) == 0: - continue - - delta = chunk.choices[0] - # NOTE: For fix https://github.com/langgenius/dify/issues/5790 - if delta.delta is None: - continue - - # extract tool calls from response - self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) - - # Handling exceptions when content filters' streaming mode is set to asynchronous modified filter - if delta.finish_reason is None and not delta.delta.content: - continue - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - - full_assistant_content += delta.delta.content or "" - - real_model = chunk.model - system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content or "" - - yield LLMResultChunk( - model=real_model, - prompt_messages=prompt_messages, - system_fingerprint=system_fingerprint, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - ), - ) - - index += 1 - - # calculate num tokens - prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - - full_assistant_prompt_message = AssistantPromptMessage(content=completion) - completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) - - # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - - yield LLMResultChunk( - model=real_model, - prompt_messages=prompt_messages, - system_fingerprint=system_fingerprint, - delta=LLMResultChunkDelta( - index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage - ), - ) - - @staticmethod - def _update_tool_calls( - tool_calls: list[AssistantPromptMessage.ToolCall], - tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], - ) -> None: - if tool_calls_response: - for response_tool_call in tool_calls_response: - if isinstance(response_tool_call, ChatCompletionMessageToolCall): - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, arguments=response_tool_call.function.arguments - ) - - tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, 
type=response_tool_call.type, function=function - ) - tool_calls.append(tool_call) - elif isinstance(response_tool_call, ChoiceDeltaToolCall): - index = response_tool_call.index - if index < len(tool_calls): - tool_calls[index].id = response_tool_call.id or tool_calls[index].id - tool_calls[index].type = response_tool_call.type or tool_calls[index].type - if response_tool_call.function: - tool_calls[index].function.name = ( - response_tool_call.function.name or tool_calls[index].function.name - ) - tool_calls[index].function.arguments += response_tool_call.function.arguments or "" - else: - assert response_tool_call.id is not None - assert response_tool_call.type is not None - assert response_tool_call.function is not None - assert response_tool_call.function.name is not None - assert response_tool_call.function.arguments is not None - - function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, arguments=response_tool_call.function.arguments - ) - tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, type=response_tool_call.type, function=function - ) - tool_calls.append(tool_call) - - @staticmethod - def _convert_prompt_message_to_dict(message: PromptMessage): - if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) - if isinstance(message.content, str): - message_dict = {"role": "user", "content": message.content} - else: - sub_messages = [] - assert message.content is not None - for message_content in message.content: - if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = {"type": "text", "text": message_content.data} - sub_messages.append(sub_message_dict) - elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) - sub_message_dict = { - "type": "image_url", - "image_url": {"url": message_content.data, "detail": message_content.detail.value}, - } - sub_messages.append(sub_message_dict) - message_dict = {"role": "user", "content": sub_messages} - elif isinstance(message, AssistantPromptMessage): - # message = cast(AssistantPromptMessage, message) - message_dict = {"role": "assistant", "content": message.content} - if message.tool_calls: - # fix azure when enable json schema cant process content = "" in assistant fix with None - if not message.content: - message_dict["content"] = None - message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls] - elif isinstance(message, SystemPromptMessage): - message = cast(SystemPromptMessage, message) - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, ToolPromptMessage): - message = cast(ToolPromptMessage, message) - message_dict = { - "role": "tool", - "name": message.name, - "content": message.content, - "tool_call_id": message.tool_call_id, - } - else: - raise ValueError(f"Got unknown type {message}") - - if message.name: - message_dict["name"] = message.name - - return message_dict - - def _num_tokens_from_string( - self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None - ) -> int: - try: - encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) - except KeyError: - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = len(encoding.encode(text)) - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - - def 
_num_tokens_from_messages( - self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None - ) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials["base_model_name"] - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - logger.warning("Warning: model not found. Using cl100k_base encoding.") - model = "cl100k_base" - encoding = tiktoken.get_encoding(model) - - if model.startswith("gpt-35-turbo-0301"): - # every message follows {role/name}\n{content}\n - tokens_per_message = 4 - # if there's a name, the role is omitted - tokens_per_name = -1 - elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or "o1" in model: - tokens_per_message = 3 - tokens_per_name = 1 - else: - raise NotImplementedError( - f"get_num_tokens_from_messages() is not presently implemented " - f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " - "information on how messages are converted to tokens." - ) - num_tokens = 0 - messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] - for message in messages_dict: - num_tokens += tokens_per_message - for key, value in message.items(): - # Cast str(value) in case the message value is not a string - # This occurs with function messages - # TODO: The current token calculation method for the image type is not implemented, - # which need to download the image and then get the resolution for calculation, - # and will increase the request delay - if isinstance(value, list): - text = "" - for item in value: - if isinstance(item, dict) and item["type"] == "text": - text += item["text"] - - value = text - - if key == "tool_calls": - for tool_call in value: - assert isinstance(tool_call, dict) - for t_key, t_value in tool_call.items(): - num_tokens += len(encoding.encode(t_key)) - if t_key == "function": - for f_key, f_value in t_value.items(): - num_tokens += len(encoding.encode(f_key)) - num_tokens += len(encoding.encode(f_value)) - else: - num_tokens += len(encoding.encode(t_key)) - num_tokens += len(encoding.encode(t_value)) - else: - num_tokens += len(encoding.encode(str(value))) - - if key == "name": - num_tokens += tokens_per_name - - # every reply is primed with assistant - num_tokens += 3 - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - - @staticmethod - def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - num_tokens = 0 - for tool in tools: - num_tokens += len(encoding.encode("type")) - num_tokens += len(encoding.encode("function")) - - # calculate num tokens for function object - num_tokens += len(encoding.encode("name")) - num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode("description")) - num_tokens += len(encoding.encode(tool.description)) - parameters = tool.parameters - num_tokens += len(encoding.encode("parameters")) - if "title" in parameters: - num_tokens += len(encoding.encode("title")) - num_tokens += len(encoding.encode(parameters["title"])) - num_tokens += len(encoding.encode("type")) - num_tokens += len(encoding.encode(parameters["type"])) - if "properties" in parameters: - num_tokens += len(encoding.encode("properties")) - for key, value in parameters["properties"].items(): - num_tokens += 
len(encoding.encode(key)) - for field_key, field_value in value.items(): - num_tokens += len(encoding.encode(field_key)) - if field_key == "enum": - for enum_field in field_value: - num_tokens += 3 - num_tokens += len(encoding.encode(enum_field)) - else: - num_tokens += len(encoding.encode(field_key)) - num_tokens += len(encoding.encode(str(field_value))) - if "required" in parameters: - num_tokens += len(encoding.encode("required")) - for required_field in parameters["required"]: - num_tokens += 3 - num_tokens += len(encoding.encode(required_field)) - - return num_tokens - - @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str): - for ai_model_entity in LLM_BASE_MODELS: - if ai_model_entity.base_model_name == base_model_name: - ai_model_entity_copy = copy.deepcopy(ai_model_entity) - ai_model_entity_copy.entity.model = model - ai_model_entity_copy.entity.label.en_US = model - ai_model_entity_copy.entity.label.zh_Hans = model - return ai_model_entity_copy - - def _get_base_model_name(self, credentials: dict) -> str: - base_model_name = credentials.get("base_model_name") - if not base_model_name: - raise ValueError("Base Model Name is required") - return base_model_name diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 8b17e8dc0a..a6214d955b 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,5 +1,5 @@ import re -from typing import Optional +from typing import Optional, cast class JiebaKeywordTableHandler: @@ -8,18 +8,20 @@ class JiebaKeywordTableHandler: from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS - jieba.analyse.default_tfidf.stop_words = STOPWORDS + jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" - import jieba # type: ignore + import jieba.analyse # type: ignore keywords = jieba.analyse.extract_tags( sentence=text, topK=max_keywords_per_chunk, ) + # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default. 
+ keywords = cast(list[str], keywords) - return set(self._expand_tokens_with_subtokens(keywords)) + return set(self._expand_tokens_with_subtokens(set(keywords))) def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: """Get subtokens from a list of tokens., filtering for stopwords.""" diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py deleted file mode 100644 index 4094207beb..0000000000 --- a/api/core/tools/tool/tool.py +++ /dev/null @@ -1,355 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Mapping -from copy import deepcopy -from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union - -from pydantic import BaseModel, ConfigDict, field_validator -from pydantic_core.core_schema import ValidationInfo - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ( - ToolDescription, - ToolIdentity, - ToolInvokeFrom, - ToolInvokeMessage, - ToolParameter, - ToolProviderType, - ToolRuntimeImageVariable, - ToolRuntimeVariable, - ToolRuntimeVariablePool, -) -from core.tools.tool_file_manager import ToolFileManager - -if TYPE_CHECKING: - from core.file.models import File - - -class Tool(BaseModel, ABC): - identity: Optional[ToolIdentity] = None - parameters: Optional[list[ToolParameter]] = None - description: Optional[ToolDescription] = None - is_team_authorization: bool = False - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - @field_validator("parameters", mode="before") - @classmethod - def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: - return v or [] - - class Runtime(BaseModel): - """ - Meta data of a tool call processing - """ - - def __init__(self, **data: Any): - super().__init__(**data) - if not self.runtime_parameters: - self.runtime_parameters = {} - - tenant_id: Optional[str] = None - tool_id: Optional[str] = None - invoke_from: Optional[InvokeFrom] = None - tool_invoke_from: Optional[ToolInvokeFrom] = None - credentials: Optional[dict[str, Any]] = None - runtime_parameters: Optional[dict[str, Any]] = None - - runtime: Optional[Runtime] = None - variables: Optional[ToolRuntimeVariablePool] = None - - def __init__(self, **data: Any): - super().__init__(**data) - - class VariableKey(StrEnum): - IMAGE = "image" - DOCUMENT = "document" - VIDEO = "video" - AUDIO = "audio" - CUSTOM = "custom" - - def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": - """ - fork a new tool with meta data - - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool - """ - return self.__class__( - identity=self.identity.model_copy() if self.identity else None, - parameters=self.parameters.copy() if self.parameters else None, - description=self.description.model_copy() if self.description else None, - runtime=Tool.Runtime(**runtime), - ) - - @abstractmethod - def tool_provider_type(self) -> ToolProviderType: - """ - get the tool provider type - - :return: the tool provider type - """ - - def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None: - """ - load variables from database - - :param conversation_id: the conversation id - """ - self.variables = variables - - def set_image_variable(self, variable_name: str, image_key: str) -> None: - """ - set an image variable - """ - if not self.variables: - return - if self.identity is None: - return - - self.variables.set_file(self.identity.name, variable_name, image_key) - - def 
set_text_variable(self, variable_name: str, text: str) -> None: - """ - set a text variable - """ - if not self.variables: - return - if self.identity is None: - return - - self.variables.set_text(self.identity.name, variable_name, text) - - def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: - """ - get a variable - - :param name: the name of the variable - :return: the variable - """ - if not self.variables: - return None - - if isinstance(name, Enum): - name = name.value - - for variable in self.variables.pool: - if variable.name == name: - return variable - - return None - - def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: - """ - get the default image variable - - :return: the image variable - """ - if not self.variables: - return None - - return self.get_variable(self.VariableKey.IMAGE) - - def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: - """ - get a variable file - - :param name: the name of the variable - :return: the variable file - """ - variable = self.get_variable(name) - if not variable: - return None - - if not isinstance(variable, ToolRuntimeImageVariable): - return None - - message_file_id = variable.value - # get file binary - file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) - if not file_binary: - return None - - return file_binary[0] - - def list_variables(self) -> list[ToolRuntimeVariable]: - """ - list all variables - - :return: the variables - """ - if not self.variables: - return [] - - return self.variables.pool - - def list_default_image_variables(self) -> list[ToolRuntimeVariable]: - """ - list all image variables - - :return: the image variables - """ - if not self.variables: - return [] - - result = [] - - for variable in self.variables.pool: - if variable.name.startswith(self.VariableKey.IMAGE.value): - result.append(variable) - - return result - - def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]: - # update tool_parameters - # TODO: Fix type error. - if self.runtime is None: - return [] - if self.runtime.runtime_parameters: - # Convert Mapping to dict before updating - tool_parameters = dict(tool_parameters) - tool_parameters.update(self.runtime.runtime_parameters) - - # try parse tool parameters into the correct type - tool_parameters = self._transform_tool_parameters_type(tool_parameters) - - result = self._invoke( - user_id=user_id, - tool_parameters=tool_parameters, - ) - - if not isinstance(result, list): - result = [result] - - if not all(isinstance(message, ToolInvokeMessage) for message in result): - raise ValueError( - f"Invalid return type from {self.__class__.__name__}._invoke method. " - "Expected ToolInvokeMessage or list of ToolInvokeMessage." 
- ) - - return result - - def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]: - """ - Transform tool parameters type - """ - # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials - result: dict[str, Any] = deepcopy(dict(tool_parameters)) - for parameter in self.parameters or []: - if parameter.name in tool_parameters: - result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) - - return result - - @abstractmethod - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - pass - - def validate_credentials( - self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False - ) -> str | None: - """ - validate the credentials - - :param credentials: the credentials - :param parameters: the parameters - :param format_only: only return the formatted - """ - pass - - def get_runtime_parameters(self) -> list[ToolParameter]: - """ - get the runtime parameters - - interface for developer to dynamic change the parameters of a tool depends on the variables pool - - :return: the runtime parameters - """ - return self.parameters or [] - - def get_all_runtime_parameters(self) -> list[ToolParameter]: - """ - get all runtime parameters - - :return: all runtime parameters - """ - parameters = self.parameters or [] - parameters = parameters.copy() - user_parameters = self.get_runtime_parameters() - user_parameters = user_parameters.copy() - - # override parameters - for parameter in user_parameters: - # check if parameter in tool parameters - found = False - for tool_parameter in parameters: - if tool_parameter.name == parameter.name: - found = True - break - - if found: - # override parameter - tool_parameter.type = parameter.type - tool_parameter.form = parameter.form - tool_parameter.required = parameter.required - tool_parameter.default = parameter.default - tool_parameter.options = parameter.options - tool_parameter.llm_description = parameter.llm_description - else: - # add new parameter - parameters.append(parameter) - - return parameters - - def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: - """ - create an image message - - :param image: the url of the image - :return: the image message - """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) - - def create_file_message(self, file: "File") -> ToolInvokeMessage: - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="") - - def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: - """ - create a link message - - :param link: the url of the link - :return: the link message - """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as) - - def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage: - """ - create a text message - - :param text: the text - :return: the text message - """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as) - - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage: - """ - create a blob message - - :param blob: the blob - :return: the blob message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=blob, - meta=meta or {}, - save_as=save_as, 
- ) - - def create_json_message(self, object: dict) -> ToolInvokeMessage: - """ - create a json message - """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index b3bcc3b2cc..5c672c985b 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -613,10 +613,10 @@ class Graph(BaseModel): for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): # check which node is after if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): - if node_id in merge_branch_node_ids: + if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id2] elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): - if node_id2 in merge_branch_node_ids: + if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id] branches_merge_node_ids: dict[str, str] = {} diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index dc919892e5..d19d6413ed 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file import File, FileTransferMethod, FileType +from core.file import File, FileTransferMethod from core.plugin.manager.exc import PluginDaemonClientSideError from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine @@ -189,10 +189,12 @@ class ToolNode(BaseNode[ToolNodeData]): conversation_id=None, ) - files: list[File] = [] text = "" + files: list[File] = [] json: list[dict] = [] + agent_logs: list[AgentLog] = [] + variables: dict[str, Any] = {} for message in message_stream: @@ -239,14 +241,16 @@ class ToolNode(BaseNode[ToolNodeData]): tool_file = session.scalar(stmt) if tool_file is None: raise ToolFileError(f"tool file {tool_file_id} not exists") + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + files.append( - File( + file_factory.build_from_mapping( + mapping=mapping, tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - extension=None, - mime_type=message.meta.get("mime_type", "application/octet-stream"), ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 30f216ff95..26bd6b3577 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -69,6 +69,7 @@ def init_app(app: DifyApp) -> Celery: "schedule.create_tidb_serverless_task", "schedule.update_tidb_serverless_status_task", "schedule.clean_messages", + "schedule.mail_clean_document_notify_task", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { @@ -92,6 +93,11 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.clean_messages.clean_messages", "schedule": timedelta(days=day), }, + # every Monday + "mail_clean_document_notify_task": { + "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", + "schedule": crontab(minute="0", hour="10", day_of_week="1"), + }, } 
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/models/tools.py b/api/models/tools.py index 3bc12e7fd7..2428ad0ac6 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Optional +from typing import Optional, Any import sqlalchemy as sa from deprecated import deprecated diff --git a/api/models/workflow.py b/api/models/workflow.py index 8a54553e3b..eba9c1b772 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -424,6 +424,18 @@ class WorkflowRun(Base): finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) + @property + def created_by_account(self): + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + + @property + def created_by_end_user(self): + from models.model import EndUser + + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + @property def graph_dict(self): return json.loads(self.graph) if self.graph else {} diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 48bdc872f4..5e4d3ec323 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -28,7 +28,6 @@ def clean_messages(): plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING ) - page = 1 while True: try: # Main query with join and filter @@ -79,4 +78,4 @@ def clean_messages(): db.session.query(Message).filter(Message.id == message.id).delete() db.session.commit() end_at = time.perf_counter() - click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) + click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 766954a257..fe6839288d 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -3,14 +3,18 @@ import time from collections import defaultdict import click -from celery import shared_task # type: ignore +from flask import render_template # type: ignore +import app +from configs import dify_config +from extensions.ext_database import db from extensions.ext_mail import mail from models.account import Account, Tenant, TenantAccountJoin from models.dataset import Dataset, DatasetAutoDisableLog +from services.feature_service import FeatureService -@shared_task(queue="mail") +@app.celery.task(queue="dataset") def send_document_clean_notify_task(): """ Async Send document clean notify mail @@ -29,35 +33,58 @@ def send_document_clean_notify_task(): # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) - + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): - 
knowledge_details = [] - tenant = Tenant.query.filter(Tenant.id == tenant_id).first() - if not tenant: - continue - current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() - if not current_owner_join: - continue - account = Account.query.filter(Account.id == current_owner_join.account_id).first() - if not account: - continue + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != "sandbox": + knowledge_details = [] + # check tenant + tenant = Tenant.query.filter(Tenant.id == tenant_id).first() + if not tenant: + continue + # check current owner + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + if not current_owner_join: + continue + account = Account.query.filter(Account.id == current_owner_join.account_id).first() + if not account: + continue - dataset_auto_dataset_map = {} # type: ignore + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = Dataset.query.filter(Dataset.id == dataset_id).first() + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + html_content = render_template( + "clean_document_job_mail_template-US.html", + userName=account.email, + knowledge_details=knowledge_details, + url=url, + ) + mail.send( + to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content + ) + + # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( - dataset_auto_disable_log.document_id - ) - - for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = Dataset.query.filter(Dataset.id == dataset_id).first() - if dataset: - document_count = len(document_ids) - knowledge_details.append(f"
<li>Knowledge base {dataset.name}: {document_count} documents
</li>") - + dataset_auto_disable_log.notified = True + db.session.commit() end_at = time.perf_counter() logging.info( click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") ) except Exception: - logging.exception("Send invite member mail to failed") + logging.exception("Send document clean notify mail failed") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index d030a1dfa9..932d68bea1 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -196,6 +196,9 @@ class AppDslService: data["kind"] = "app" imported_version = data.get("version", "0.1.0") + # check if imported_version is a float-like string + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") status = _check_version_compatibility(imported_version) # Extract app data diff --git a/api/services/errors/base.py b/api/services/errors/base.py index 4d39f956b8..35ea28468e 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,6 +1,6 @@ from typing import Optional -class BaseServiceError(Exception): +class BaseServiceError(ValueError): def __init__(self, description: Optional[str] = None): self.description = description diff --git a/api/services/message_service.py b/api/services/message_service.py index c4447a84da..c17122ef64 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -152,6 +152,7 @@ class MessageService: @classmethod def create_feedback( cls, + *, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 063a3cb1b5..10ce73c208 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -64,7 +64,10 @@ class ToolTransformService: ) elif isinstance(provider, ToolProviderApiEntity): if provider.plugin_id: - provider.icon = ToolTransformService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon) + if isinstance(provider.icon, str): + provider.icon = ToolTransformService.get_plugin_icon_url( + tenant_id=tenant_id, filename=provider.icon + ) else: provider.icon = ToolTransformService.get_tool_provider_icon_url( provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 95649106e2..2de3d0ac55 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,6 +3,7 @@ import time from collections.abc import Callable, Generator, Sequence from datetime import UTC, datetime from typing import Any, Optional +from uuid import uuid4 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -333,6 +334,7 @@ class WorkflowService: error = e.error workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) workflow_node_execution.tenant_id = tenant_id workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value workflow_node_execution.index = 1 diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html index b7c9538f9f..88e78f41c7 100644 --- a/api/templates/clean_document_job_mail_template-US.html +++
b/api/templates/clean_document_job_mail_template-US.html @@ -45,14 +45,14 @@ .content ul li { margin-bottom: 10px; } - .cta-button { + .cta-button, .cta-button:hover, .cta-button:active, .cta-button:visited, .cta-button:focus { display: block; margin: 20px auto; padding: 10px 20px; background-color: #4e89f9; - color: #ffffff; + color: #ffffff !important; text-align: center; - text-decoration: none; + text-decoration: none !important; border-radius: 5px; width: fit-content; } @@ -69,7 +69,7 @@