diff --git a/.gitignore b/.gitignore
index 5512b49447..658c9319b2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -145,6 +145,9 @@
 docker/volumes/db/data/*
 docker/volumes/redis/data/*
 docker/volumes/weaviate/*
 docker/volumes/qdrant/*
+docker/volumes/etcd/*
+docker/volumes/minio/*
+docker/volumes/milvus/*
 
 sdks/python-client/build
 sdks/python-client/dist
diff --git a/api/app.py b/api/app.py
index ac8bf27df1..aea28ac93a 100644
--- a/api/app.py
+++ b/api/app.py
@@ -26,6 +26,7 @@ from config import CloudEditionConfig, Config
 from extensions import (
     ext_celery,
     ext_code_based_extension,
+    ext_compress,
     ext_database,
     ext_hosting_provider,
     ext_login,
@@ -96,6 +97,7 @@ def create_app(test_config=None) -> Flask:
 def initialize_extensions(app):
     # Since the application instance is now created, pass it to each Flask
     # extension instance to bind it to the Flask application instance (app)
+    ext_compress.init_app(app)
     ext_code_based_extension.init()
     ext_database.init_app(app)
     ext_migrate.init(app, db)
diff --git a/api/config.py b/api/config.py
index 3f6980bdea..7c46426b47 100644
--- a/api/config.py
+++ b/api/config.py
@@ -90,7 +90,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.5.7"
+        self.CURRENT_VERSION = "0.5.8"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -293,6 +293,8 @@ class Config:
 
         self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
 
+        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
+
 
 class CloudEditionConfig(Config):
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index d2c1891b65..77eaf136fc 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -88,7 +88,7 @@ class ChatMessageTextApi(Resource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=request.form['text'],
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=False
         )
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index f957d38174..dc546ce0dd 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -85,7 +85,7 @@ class ChatTextApi(InstalledAppResource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=request.form['text'],
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=False
         )
         return {'data': response.data.decode('latin1')}
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 58ab56a292..60ca2171d5 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -87,7 +87,7 @@ class TextApi(Resource):
             tenant_id=app_model.tenant_id,
             text=args['text'],
             end_user=end_user,
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=args['streaming']
         )
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index c628c16606..4e677ae288 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -84,7 +84,7 @@ class TextApi(WebApiResource):
             tenant_id=app_model.tenant_id,
             text=request.form['text'],
             end_user=end_user.external_user_id,
-            voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=False
        )
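All four audio controllers above gain the same fallback: prefer an explicit `voice` field from the request, otherwise use the app's configured text-to-speech voice. A minimal standalone sketch of that pattern (illustrative only, not part of the patch; note that `request.form['voice']` raises `KeyError` when the field is missing entirely, so `.get()` is the defensive lookup):

def resolve_voice(form: dict, configured_voice: str | None) -> str | None:
    # Hypothetical helper: `form` stands in for Flask's request.form and
    # `configured_voice` for app_model_config.text_to_speech_dict['voice'].
    requested = form.get('voice')  # .get() avoids KeyError on an absent field
    return requested if requested else configured_voice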
diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py
index 8fcbff983d..3762ddcf62 100644
--- a/api/core/features/assistant_cot_runner.py
+++ b/api/core/features/assistant_cot_runner.py
@@ -28,6 +28,9 @@ from models.model import Conversation, Message
 
 class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
+    _is_first_iteration = True
+    _ignore_observation_providers = ['wenxin']
+
     def run(self, conversation: Conversation,
             message: Message,
             query: str,
@@ -42,10 +45,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         agent_scratchpad: list[AgentScratchpadUnit] = []
         self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
 
-        # check model mode
-        if self.app_orchestration_config.model_config.mode == "completion":
-            # TODO: stop words
-            if 'Observation' not in app_orchestration_config.model_config.stop:
+        if 'Observation' not in app_orchestration_config.model_config.stop:
+            if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
                 app_orchestration_config.model_config.stop.append('Observation')
 
         # override inputs
@@ -202,6 +203,7 @@
                 )
             )
 
+            scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
             agent_scratchpad.append(scratchpad)
 
             # get llm usage
@@ -255,9 +257,15 @@
             # invoke tool
             error_response = None
             try:
+                if isinstance(tool_call_args, str):
+                    try:
+                        tool_call_args = json.loads(tool_call_args)
+                    except json.JSONDecodeError:
+                        pass
+
                 tool_response = tool_instance.invoke(
                     user_id=self.user_id,
-                    tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
+                    tool_parameters=tool_call_args
                 )
                 # transform tool response to llm friendly response
                 tool_response = self.transform_tool_invoke_messages(tool_response)
@@ -466,7 +474,7 @@
             if isinstance(message, AssistantPromptMessage):
                 current_scratchpad = AgentScratchpadUnit(
                     agent_response=message.content,
-                    thought=message.content,
+                    thought=message.content or 'I am thinking about how to help you',
                     action_str='',
                     action=None,
                     observation=None,
@@ -546,7 +554,8 @@
         result = ''
         for scratchpad in agent_scratchpad:
-            result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
+            result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
+                next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
 
         return result
@@ -621,21 +630,24 @@
             # add assistant message
-            if len(agent_scratchpad) > 0:
+            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                 prompt_messages.append(AssistantPromptMessage(
-                    content=(agent_scratchpad[-1].thought or '')
+                    content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
                 ))
 
             # add user message
-            if len(agent_scratchpad) > 0:
+            if len(agent_scratchpad) > 0 and not self._is_first_iteration:
                 prompt_messages.append(UserPromptMessage(
-                    content=(agent_scratchpad[-1].observation or ''),
+                    content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
                 ))
 
+            self._is_first_iteration = False
+
             return prompt_messages
         elif mode == "completion":
             # parse agent scratchpad
             agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
+            self._is_first_iteration = False
             # parse prompt messages
             return [UserPromptMessage(
                 content=first_prompt.replace("{{instruction}}", instruction)
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index f5ea49bb5e..0cd9f9f646 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -62,7 +62,8 @@ class IndexingRunner:
                 text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
 
                 # transform
-                documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+                documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
+                                            processing_rule.to_dict())
                 # save segment
                 self._load_segments(dataset, dataset_document, documents)
 
@@ -120,7 +121,8 @@ class IndexingRunner:
             text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
 
             # transform
-            documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
+            documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
+                                        processing_rule.to_dict())
             # save segment
             self._load_segments(dataset, dataset_document, documents)
 
@@ -186,7 +188,7 @@ class IndexingRunner:
                 first()
 
             index_type = dataset_document.doc_form
-            index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
+            index_processor = IndexProcessorFactory(index_type).init_index_processor()
             self._load(
                 index_processor=index_processor,
                 dataset=dataset,
@@ -750,7 +752,7 @@ class IndexingRunner:
         index_processor.load(dataset, documents)
 
     def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
-                   text_docs: list[Document], process_rule: dict) -> list[Document]:
+                   text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
         # get embedding model instance
         embedding_model_instance = None
         if dataset.indexing_technique == 'high_quality':
@@ -768,7 +770,8 @@ class IndexingRunner:
             )
 
         documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
-                                              process_rule=process_rule)
+                                              process_rule=process_rule, tenant_id=dataset.tenant_id,
+                                              doc_language=doc_language)
 
         return documents
diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.py b/api/core/model_runtime/model_providers/anthropic/anthropic.py
index ece6d2a7a4..00a6bbce3b 100644
--- a/api/core/model_runtime/model_providers/anthropic/anthropic.py
+++ b/api/core/model_runtime/model_providers/anthropic/anthropic.py
@@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):
 
             # Use `claude-instant-1` model for validate,
             model_instance.validate_credentials(
-                model='claude-instant-1',
+                model='claude-instant-1.2',
                 credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
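The CoT runner above now accepts tool arguments that arrive either as a dict or as a JSON string, and keeps the raw string when decoding fails so that `Tool.invoke` (patched near the end of this diff) can still map it onto a single LLM parameter. A standalone sketch of that decode-with-fallback step (illustrative, standard library only):

import json

def normalize_tool_args(tool_call_args):
    # Dicts pass through untouched; strings are decoded when they hold JSON,
    # otherwise returned verbatim for single-parameter tools to consume.
    if isinstance(tool_call_args, str):
        try:
            return json.loads(tool_call_args)
        except json.JSONDecodeError:
            return tool_call_args
    return tool_call_args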
diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.yaml b/api/core/model_runtime/model_providers/anthropic/anthropic.yaml
index d32b763301..cf41f544ef 100644
--- a/api/core/model_runtime/model_providers/anthropic/anthropic.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/anthropic.yaml
@@ -2,8 +2,8 @@ provider: anthropic
 label:
   en_US: Anthropic
 description:
-  en_US: Anthropic’s powerful models, such as Claude 2 and Claude Instant.
-  zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant。
+  en_US: Anthropic’s powerful models, such as Claude 3.
+  zh_Hans: Anthropic 的强大模型,例如 Claude 3。
 icon_small:
   en_US: icon_s_en.svg
 icon_large:
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml
new file mode 100644
index 0000000000..e7b002878a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml
@@ -0,0 +1,6 @@
+- claude-3-opus-20240229
+- claude-3-sonnet-20240229
+- claude-2.1
+- claude-instant-1.2
+- claude-2
+- claude-instant-1
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
index 12faf60bc9..1986947129 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
@@ -34,3 +34,4 @@ pricing:
   output: '24.00'
   unit: '0.000001'
   currency: USD
+deprecated: true
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
new file mode 100644
index 0000000000..ab3e92a059
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml
@@ -0,0 +1,37 @@
+model: claude-3-opus-20240229
+label:
+  en_US: claude-3-opus-20240229
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '15.00'
+  output: '75.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
new file mode 100644
index 0000000000..65cdab9bc6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml
@@ -0,0 +1,37 @@
+model: claude-3-sonnet-20240229
+label:
+  en_US: claude-3-sonnet-20240229
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '3.00'
+  output: '15.00'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml
new file mode 100644
index 0000000000..929a7f8725
--- /dev/null
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml
@@ -0,0 +1,35 @@
+model: claude-instant-1.2
+label:
+  en_US: claude-instant-1.2
+model_type: llm
+features: [ ]
+model_properties:
+  mode: chat
+  context_size: 100000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 4096
+    min: 1
+    max: 4096
+  - name: response_format
+    use_template: response_format
+pricing:
+  input: '1.63'
+  output: '5.51'
+  unit: '0.000001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
index 25d32a09af..5e76d5b1c2 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
@@ -33,3 +33,4 @@ pricing:
   output: '5.51'
   unit: '0.000001'
   currency: USD
+deprecated: true
""" + class AnthropicLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, @@ -55,54 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: full response or stream response chunk generator result """ # invoke model - return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - + return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + + def _chat_generate(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + """ + Invoke llm chat model + + :param model: model name + :param credentials: credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + + # transform model parameters from completion api of anthropic to chat api + if 'max_tokens_to_sample' in model_parameters: + model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + + # init model client + client = Anthropic(**credentials_kwargs) + + extra_model_kwargs = {} + if stop: + extra_model_kwargs['stop_sequences'] = stop + + if user: + extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + + system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) + + if system: + extra_model_kwargs['system'] = system + + # chat model + response = client.messages.create( + model=model, + messages=prompt_message_dicts, + stream=stream, + **model_parameters, + **extra_model_kwargs + ) + + if stream: + return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_chat_generate_response(model, credentials, response, prompt_messages) + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ if 'response_format' in model_parameters and model_parameters['response_format']: stop = stop or [] - self._transform_json_prompts( - model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format'] + # chat model + self._transform_chat_json_prompts( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters['response_format'] ) model_parameters.pop('response_format') return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: 
str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: """ Transform json prompts """ if "```\n" not in stop: stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") # check if there is a system message if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{block}}", response_format) ) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) else: # insert the system message prompt_messages.insert(0, SystemPromptMessage( content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) + .replace("{{instructions}}", f"Please output a valid {response_format} object.") + .replace("{{block}}", response_format) )) - - prompt_messages.append(AssistantPromptMessage( - content=f"```{response_format}\n" - )) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -129,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): :return: """ try: - self._generate( + self._chat_generate( model=model, credentials=credentials, prompt_messages=[ @@ -137,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): ], model_parameters={ "temperature": 0, - "max_tokens_to_sample": 20, + "max_tokens": 20, }, stream=False ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message, + prompt_messages: list[PromptMessage]) -> LLMResult: """ - Invoke large language model - - :param model: model name - :param credentials: credentials kwargs - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :return: full response or stream response chunk generator result - """ - # transform credentials to kwargs for model instance - credentials_kwargs = self._to_credential_kwargs(credentials) - - client = Anthropic(**credentials_kwargs) - - extra_model_kwargs = {} - if stop: - extra_model_kwargs['stop_sequences'] = stop - - if user: - extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) - - response = client.completions.create( - model=model, - 
prompt=self._convert_messages_to_prompt_anthropic(prompt_messages), - stream=stream, - **model_parameters, - **extra_model_kwargs - ) - - if stream: - return self._handle_generate_stream_response(model, credentials, response, prompt_messages) - - return self._handle_generate_response(model, credentials, response, prompt_messages) - - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: - """ - Handle llm response + Handle llm chat response :param model: model name :param credentials: credentials @@ -198,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): """ # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=response.completion + content=response.content[0].text ) # calculate num tokens - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + if response.usage: + # transform usage + prompt_tokens = response.usage.input_tokens + completion_tokens = response.usage.output_tokens + else: + # calculate num tokens + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) # transform response - result = LLMResult( + response = LLMResult( model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, - usage=usage, + usage=usage ) - return result + return response - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response(self, model: str, credentials: dict, + response: Stream[MessageStreamEvent], + prompt_messages: list[PromptMessage]) -> Generator: """ - Handle llm stream response + Handle llm chat stream response :param model: model name - :param credentials: credentials :param response: response :param prompt_messages: prompt messages - :return: llm response chunk generator result + :return: llm response chunk generator """ - index = -1 + full_assistant_content = '' + return_model = None + input_tokens = 0 + output_tokens = 0 + finish_reason = None + index = 0 for chunk in response: - content = chunk.completion - if chunk.stop_reason is None and (content is None or content == ''): - continue - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=content if content else '', - ) - - index += 1 - - if chunk.stop_reason is not None: - # calculate num tokens - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) - + if isinstance(chunk, MessageStartEvent): + return_model = chunk.message.model + input_tokens = chunk.message.usage.input_tokens + elif isinstance(chunk, MessageDeltaEvent): + output_tokens = chunk.usage.output_tokens + finish_reason = chunk.delta.stop_reason + elif isinstance(chunk, MessageStopEvent): # transform usage - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) yield LLMResultChunk( - model=chunk.model, + model=return_model, 
prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=chunk.stop_reason, + index=index + 1, + message=AssistantPromptMessage( + content='' + ), + finish_reason=finish_reason, usage=usage ) ) - else: + elif isinstance(chunk, ContentBlockDeltaEvent): + chunk_text = chunk.delta.text if chunk.delta.text else '' + full_assistant_content += chunk_text + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=chunk_text + ) + + index = chunk.index + yield LLMResultChunk( - model=chunk.model, + model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message + index=chunk.index, + message=assistant_prompt_message, ) ) @@ -289,6 +337,80 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return credentials_kwargs + def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: + """ + Convert prompt messages to dict list and system + """ + system = "" + prompt_message_dicts = [] + + for message in prompt_messages: + if isinstance(message, SystemPromptMessage): + system += message.content + ("\n" if not system else "") + else: + prompt_message_dicts.append(self._convert_prompt_message_to_dict(message)) + + return system, prompt_message_dicts + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + sub_message_dict = { + "type": "text", + "text": message_content.data + } + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + if not message_content.data.startswith("data:"): + # fetch image data from url + try: + image_content = requests.get(message_content.data).content + mime_type, _ = mimetypes.guess_type(message_content.data) + base64_data = base64.b64encode(image_content).decode('utf-8') + except Exception as ex: + raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + else: + data_split = message_content.data.split(";base64,") + mime_type = data_split[0].replace("data:", "") + base64_data = data_split[1] + + if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + raise ValueError(f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp") + + sub_message_dict = { + "type": "image", + "source": { + "type": "base64", + "media_type": mime_type, + "data": base64_data + } + } + sub_messages.append(sub_message_dict) + + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + + return message_dict + def 
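The rewritten provider targets Anthropic's Messages API: the system prompt travels as a separate `system` argument, user content can be a list of typed text/image blocks, and token usage comes back on the response instead of being re-counted locally. A minimal sketch of the request shape the converted dicts feed into (anthropic SDK ~0.17 as pinned in requirements.txt; the key and base64 payload are placeholders):

import anthropic

client = anthropic.Anthropic(api_key='sk-ant-...')  # placeholder credential
response = client.messages.create(
    model='claude-3-sonnet-20240229',
    max_tokens=256,
    system='You are a concise assistant.',  # system prompt is not a message
    messages=[{
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'Describe this image.'},
            {'type': 'image', 'source': {'type': 'base64',
                                         'media_type': 'image/png',
                                         'data': '<base64-bytes>'}},
        ],
    }],
)
# usage is reported by the API, mirroring _handle_chat_generate_response above
print(response.content[0].text, response.usage.input_tokens, response.usage.output_tokens)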
diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
index da4ba55881..535714f663 100644
--- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
@@ -108,7 +108,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
         try:
             response = post(url, headers=headers, data=dumps(data))
         except Exception as e:
-            raise InvokeConnectionError(e)
+            raise InvokeConnectionError(str(e))
 
         if response.status_code != 200:
             try:
diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
index 5c146972cd..da922232c0 100644
--- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py
@@ -57,7 +57,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
         try:
             response = post(url, headers=headers, data=dumps(data))
         except Exception as e:
-            raise InvokeConnectionError(e)
+            raise InvokeConnectionError(str(e))
 
         if response.status_code != 200:
             try:
diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
index 39143127eb..c95007d271 100644
--- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py
@@ -59,7 +59,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
         try:
             response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
         except Exception as e:
-            raise InvokeConnectionError(e)
+            raise InvokeConnectionError(str(e))
 
         if response.status_code != 200:
             try:
diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
index edf4d6005a..85dc6ef51d 100644
--- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py
@@ -65,7 +65,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
         try:
             response = post(url, headers=headers, data=dumps(data))
         except Exception as e:
-            raise InvokeConnectionError(e)
+            raise InvokeConnectionError(str(e))
 
         if response.status_code != 200:
             raise InvokeServerUnavailableError(response.text)
diff --git a/api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png
index de199b4317..de27b57512 100644
Binary files a/api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png and b/api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png differ
diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py
index b1718c063c..f5e2ec4b7c 100644
--- a/api/core/model_runtime/model_providers/openai/tts/tts.py
+++ b/api/core/model_runtime/model_providers/openai/tts/tts.py
@@ -34,7 +34,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
         :return: text translated to audio file
         """
         audio_type = self._get_model_audio_type(model, credentials)
-        if not voice:
+        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
             voice = self._get_model_default_voice(model, credentials)
         if streaming:
             return Response(stream_with_context(self._tts_invoke_streaming(model=model,
diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
index 33847c0cb3..4dbd0678e7 100644
--- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py
@@ -53,7 +53,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
             # cloud not connect to the server
             raise InvokeAuthorizationError(f"Invalid server URL: {e}")
         except Exception as e:
-            raise InvokeConnectionError(e)
+            raise InvokeConnectionError(str(e))
 
         if response.status_code != 200:
             if response.status_code == 400:
diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py
index 6bd17684fe..937f469bdf 100644
--- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py
+++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py
@@ -34,7 +34,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
         :return: text translated to audio file
         """
         audio_type = self._get_model_audio_type(model, credentials)
-        if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
+        if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
             voice = self._get_model_default_voice(model, credentials)
         if streaming:
             return Response(stream_with_context(self._tts_invoke_streaming(model=model,
diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
index 24a91af62c..66dab65804 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py
+++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
@@ -1,10 +1,10 @@
-from os import path
 from threading import Lock
 from time import time
 
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
 from requests.sessions import Session
+from yarl import URL
 
 
 class XinferenceModelExtraParameter:
@@ -55,7 +55,10 @@ class XinferenceHelper:
         get xinference model extra parameter like model_format and model_handle_type
         """
 
-        url = path.join(server_url, 'v1/models', model_uid)
+        if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
+            raise RuntimeError('model_uid is empty')
+
+        url = str(URL(server_url) / 'v1' / 'models' / model_uid)
 
         # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()
@@ -66,7 +69,6 @@ class XinferenceHelper:
             response = session.get(url, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
-
         if response.status_code != 200:
             raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
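Replacing `os.path.join` with yarl's `/` operator in the xinference helper matters because path joining is a filesystem operation, not a URL operation: on Windows it inserts backslashes into the endpoint. A quick illustration (yarl ~1.9.4 as pinned later in requirements.txt; the server URL and model uid are made-up values):

from os import path

from yarl import URL

server_url = 'http://127.0.0.1:9997'
# On Windows, ntpath.join yields 'http://127.0.0.1:9997\v1/models\foo',
# which is not a valid request URL
print(path.join(server_url, 'v1/models', 'foo'))
# yarl joins URL path segments portably and normalizes slashes
print(str(URL(server_url) / 'v1' / 'models' / 'foo'))  # http://127.0.0.1:9997/v1/models/foo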
diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py
index c391d7ae66..a8077971dc 100644
--- a/api/core/rag/extractor/csv_extractor.py
+++ b/api/core/rag/extractor/csv_extractor.py
@@ -3,6 +3,7 @@ import csv
 from typing import Optional
 
 from core.rag.extractor.extractor_base import BaseExtractor
+from core.rag.extractor.helpers import detect_file_encodings
 from core.rag.models.document import Document
 
 
@@ -36,7 +37,7 @@ class CSVExtractor(BaseExtractor):
                 docs = self._read_from_file(csvfile)
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
-                detected_encodings = detect_filze_encodings(self._file_path)
+                detected_encodings = detect_file_encodings(self._file_path)
                 for encoding in detected_encodings:
                     try:
                         with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py
index f61c728b49..0d81c419d6 100644
--- a/api/core/rag/index_processor/processor/qa_index_processor.py
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py
@@ -7,7 +7,6 @@ from typing import Optional
 
 import pandas as pd
 from flask import Flask, current_app
-from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 
 from core.generator.llm_generator import LLMGenerator
@@ -31,7 +30,7 @@ class QAIndexProcessor(BaseIndexProcessor):
 
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
-                                      embedding_model_instance=None)
+                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
 
         # Split the text documents into nodes.
         all_documents = []
@@ -66,10 +65,10 @@ class QAIndexProcessor(BaseIndexProcessor):
             for doc in sub_documents:
                 document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={
                     'flask_app': current_app._get_current_object(),
-                    'tenant_id': current_user.current_tenant.id,
+                    'tenant_id': kwargs.get('tenant_id'),
                     'document_node': doc,
                     'all_qa_documents': all_qa_documents,
-                    'document_language': kwargs.get('document_language', 'English')})
+                    'document_language': kwargs.get('doc_language', 'English')})
                 threads.append(document_format_thread)
                 document_format_thread.start()
             for thread in threads:
diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml
index e796e58e13..2f332b1c31 100644
--- a/api/core/tools/provider/_position.yaml
+++ b/api/core/tools/provider/_position.yaml
@@ -18,3 +18,4 @@
 - vectorizer
 - gaode
 - wecom
+- qrcode
diff --git a/api/core/tools/provider/builtin/qrcode/_assets/icon.svg b/api/core/tools/provider/builtin/qrcode/_assets/icon.svg
new file mode 100644
index 0000000000..d44bb0bca9
--- /dev/null
+++ b/api/core/tools/provider/builtin/qrcode/_assets/icon.svg
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.py b/api/core/tools/provider/builtin/qrcode/qrcode.py
new file mode 100644
index 0000000000..9fa7d01265
--- /dev/null
+++ b/api/core/tools/provider/builtin/qrcode/qrcode.py
@@ -0,0 +1,16 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class QRCodeProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            QRCodeGeneratorTool().invoke(user_id='',
+                                         tool_parameters={
+                                             'content': 'Dify 123 😊'
+                                         })
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.yaml b/api/core/tools/provider/builtin/qrcode/qrcode.yaml
new file mode 100644
index 0000000000..c117c3de74
--- /dev/null
+++ b/api/core/tools/provider/builtin/qrcode/qrcode.yaml
@@ -0,0 +1,12 @@
+identity:
+  author: Bowen Liang
+  name: qrcode
+  label:
+    en_US: QRCode
+    zh_Hans: 二维码工具
+    pt_BR: QRCode
+  description:
+    en_US: A tool for generating QR code (quick-response code) image.
+    zh_Hans: 一个二维码工具
+    pt_BR: A tool for generating QR code (quick-response code) image.
+  icon: icon.svg
diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
new file mode 100644
index 0000000000..a86f17a999
--- /dev/null
+++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py
@@ -0,0 +1,35 @@
+import io
+import logging
+from typing import Any, Union
+
+import qrcode
+from qrcode.image.pure import PyPNGImage
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class QRCodeGeneratorTool(BuiltinTool):
+    def _invoke(self,
+                user_id: str,
+                tool_parameters: dict[str, Any],
+                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+            invoke tools
+        """
+        # get expression
+        content = tool_parameters.get('content', '')
+        if not content:
+            return self.create_text_message('Invalid parameter content')
+
+        try:
+            img = qrcode.make(data=content, image_factory=PyPNGImage)
+            byte_stream = io.BytesIO()
+            img.save(byte_stream)
+            byte_array = byte_stream.getvalue()
+            return self.create_blob_message(blob=byte_array,
+                                            meta={'mime_type': 'image/png'},
+                                            save_as=self.VARIABLE_KEY.IMAGE.value)
+        except Exception:
+            logging.exception(f'Failed to generate QR code for content: {content}')
+            return self.create_text_message('Failed to generate QR code')
diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.yaml b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.yaml
new file mode 100644
index 0000000000..ca562ac094
--- /dev/null
+++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.yaml
@@ -0,0 +1,26 @@
+identity:
+  name: qrcode_generator
+  author: Bowen Liang
+  label:
+    en_US: QR Code Generator
+    zh_Hans: 二维码生成器
+    pt_BR: QR Code Generator
+description:
+  human:
+    en_US: A tool for generating QR code image
+    zh_Hans: 一个用于生成二维码的工具
+    pt_BR: A tool for generating QR code image
+  llm: A tool for generating QR code image
+parameters:
+  - name: content
+    type: string
+    required: true
+    label:
+      en_US: content text for QR code
+      zh_Hans: 二维码文本内容
+      pt_BR: content text for QR code
+    human_description:
+      en_US: content text for QR code
+      zh_Hans: 二维码文本内容
+      pt_BR: 二维码文本内容
+    form: llm
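The generator above delegates the actual encoding to the `qrcode` package's pure-Python PNG backend, so no Pillow install is needed to produce the blob. The same call chain run standalone (a sketch; the output filename is illustrative):

import io

import qrcode
from qrcode.image.pure import PyPNGImage

# PyPNGImage renders the QR matrix straight to PNG bytes without Pillow
img = qrcode.make(data='https://dify.ai', image_factory=PyPNGImage)
buffer = io.BytesIO()
img.save(buffer)
png_bytes = buffer.getvalue()  # the same bytes the tool wraps in a blob message

with open('qr.png', 'wb') as f:  # illustrative output path
    f.write(png_bytes)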
diff --git a/api/core/tools/provider/builtin/tavily/_assets/icon.png b/api/core/tools/provider/builtin/tavily/_assets/icon.png
new file mode 100644
index 0000000000..fdb40ab568
Binary files /dev/null and b/api/core/tools/provider/builtin/tavily/_assets/icon.png differ
diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py
new file mode 100644
index 0000000000..a013d41fcf
--- /dev/null
+++ b/api/core/tools/provider/builtin/tavily/tavily.py
@@ -0,0 +1,22 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin.tavily.tools.tavily_search import TavilySearchTool
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class TavilyProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            TavilySearchTool().fork_tool_runtime(
+                meta={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id='',
+                tool_parameters={
+                    "query": "Sachin Tendulkar",
+                },
+            )
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/tavily/tavily.yaml b/api/core/tools/provider/builtin/tavily/tavily.yaml
new file mode 100644
index 0000000000..50826e37b3
--- /dev/null
+++ b/api/core/tools/provider/builtin/tavily/tavily.yaml
@@ -0,0 +1,29 @@
+identity:
+  author: Yash Parmar
+  name: tavily
+  label:
+    en_US: Tavily
+    zh_Hans: Tavily
+    pt_BR: Tavily
+  description:
+    en_US: Tavily
+    zh_Hans: Tavily
+    pt_BR: Tavily
+  icon: icon.png
+credentials_for_provider:
+  tavily_api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: Tavily API key
+      zh_Hans: Tavily API key
+      pt_BR: Tavily API key
+    placeholder:
+      en_US: Please input your Tavily API key
+      zh_Hans: 请输入你的 Tavily API key
+      pt_BR: Please input your Tavily API key
+    help:
+      en_US: Get your Tavily API key from Tavily
+      zh_Hans: 从 TavilyApi 获取您的 Tavily API key
+      pt_BR: Get your Tavily API key from Tavily
+    url: https://docs.tavily.com/docs/tavily-api/introduction
diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py
new file mode 100644
index 0000000000..9a4d27376b
--- /dev/null
+++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py
@@ -0,0 +1,161 @@
+from typing import Any, Optional
+
+import requests
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+TAVILY_API_URL = "https://api.tavily.com"
+
+
+class TavilySearch:
+    """
+    A class for performing search operations using the Tavily Search API.
+
+    Args:
+        api_key (str): The API key for accessing the Tavily Search API.
+
+    Methods:
+        raw_results: Retrieves raw search results from the Tavily Search API.
+        results: Retrieves cleaned search results from the Tavily Search API.
+        clean_results: Cleans the raw search results.
+    """
+
+    def __init__(self, api_key: str) -> None:
+        self.api_key = api_key
+
+    def raw_results(
+        self,
+        query: str,
+        max_results: Optional[int] = 3,
+        search_depth: Optional[str] = "advanced",
+        include_domains: Optional[list[str]] = [],
+        exclude_domains: Optional[list[str]] = [],
+        include_answer: Optional[bool] = False,
+        include_raw_content: Optional[bool] = False,
+        include_images: Optional[bool] = False,
+    ) -> dict:
+        """
+        Retrieves raw search results from the Tavily Search API.
+
+        Args:
+            query (str): The search query.
+            max_results (int, optional): The maximum number of results to retrieve. Defaults to 3.
+            search_depth (str, optional): The search depth. Defaults to "advanced".
+            include_domains (List[str], optional): The domains to include in the search. Defaults to [].
+            exclude_domains (List[str], optional): The domains to exclude from the search. Defaults to [].
+            include_answer (bool, optional): Whether to include answer in the search results. Defaults to False.
+            include_raw_content (bool, optional): Whether to include raw content in the search results. Defaults to False.
+            include_images (bool, optional): Whether to include images in the search results. Defaults to False.
+
+        Returns:
+            dict: The raw search results.
+
+        """
+        params = {
+            "api_key": self.api_key,
+            "query": query,
+            "max_results": max_results,
+            "search_depth": search_depth,
+            "include_domains": include_domains,
+            "exclude_domains": exclude_domains,
+            "include_answer": include_answer,
+            "include_raw_content": include_raw_content,
+            "include_images": include_images,
+        }
+        response = requests.post(f"{TAVILY_API_URL}/search", json=params)
+        response.raise_for_status()
+        return response.json()
+
+    def results(
+        self,
+        query: str,
+        max_results: Optional[int] = 3,
+        search_depth: Optional[str] = "advanced",
+        include_domains: Optional[list[str]] = [],
+        exclude_domains: Optional[list[str]] = [],
+        include_answer: Optional[bool] = False,
+        include_raw_content: Optional[bool] = False,
+        include_images: Optional[bool] = False,
+    ) -> list[dict]:
+        """
+        Retrieves cleaned search results from the Tavily Search API.
+
+        Args:
+            query (str): The search query.
+            max_results (int, optional): The maximum number of results to retrieve. Defaults to 3.
+            search_depth (str, optional): The search depth. Defaults to "advanced".
+            include_domains (List[str], optional): The domains to include in the search. Defaults to [].
+            exclude_domains (List[str], optional): The domains to exclude from the search. Defaults to [].
+            include_answer (bool, optional): Whether to include answer in the search results. Defaults to False.
+            include_raw_content (bool, optional): Whether to include raw content in the search results. Defaults to False.
+            include_images (bool, optional): Whether to include images in the search results. Defaults to False.
+
+        Returns:
+            list: The cleaned search results.
+
+        """
+        raw_search_results = self.raw_results(
+            query,
+            max_results=max_results,
+            search_depth=search_depth,
+            include_domains=include_domains,
+            exclude_domains=exclude_domains,
+            include_answer=include_answer,
+            include_raw_content=include_raw_content,
+            include_images=include_images,
+        )
+        return self.clean_results(raw_search_results["results"])
+
+    def clean_results(self, results: list[dict]) -> list[dict]:
+        """
+        Cleans the raw search results.
+
+        Args:
+            results (list): The raw search results.
+
+        Returns:
+            list: The cleaned search results.
+
+        """
+        clean_results = []
+        for result in results:
+            clean_results.append(
+                {
+                    "url": result["url"],
+                    "content": result["content"],
+                }
+            )
+        # return clean results as a string
+        return "\n".join([f"{res['url']}\n{res['content']}" for res in clean_results])
+
+
+class TavilySearchTool(BuiltinTool):
+    """
+    A tool for searching Tavily using a given query.
+    """
+
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        """
+        Invokes the Tavily search tool with the given user ID and tool parameters.
+
+        Args:
+            user_id (str): The ID of the user invoking the tool.
+            tool_parameters (Dict[str, Any]): The parameters for the Tavily search tool.
+
+        Returns:
+            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily search tool invocation.
+        """
+        query = tool_parameters.get("query", "")
+        api_key = self.runtime.credentials["tavily_api_key"]
+        if not query:
+            return self.create_text_message("Please input query")
+        tavily_search = TavilySearch(api_key)
+        results = tavily_search.results(query)
+        print(results)
+        if not results:
+            return self.create_text_message(f"No results found for '{query}' in Tavily")
+        else:
+            return self.create_text_message(text=results)
diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.yaml b/api/core/tools/provider/builtin/tavily/tools/tavily_search.yaml
new file mode 100644
index 0000000000..ccdb9408fc
--- /dev/null
+++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.yaml
@@ -0,0 +1,27 @@
+identity:
+  name: tavily_search
+  author: Yash Parmar
+  label:
+    en_US: TavilySearch
+    zh_Hans: TavilySearch
+    pt_BR: TavilySearch
+description:
+  human:
+    en_US: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
+    zh_Hans: 专为人工智能代理 (LLM) 构建的搜索引擎工具,可快速提供实时、准确和真实的结果。
+    pt_BR: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
+  llm: A tool for search engine built specifically for AI agents (LLMs), delivering real-time, accurate, and factual results at speed.
+parameters:
+  - name: query
+    type: string
+    required: true
+    label:
+      en_US: Query string
+      zh_Hans: 查询语句
+      pt_BR: Query string
+    human_description:
+      en_US: used for searching
+      zh_Hans: 用于搜索网页内容
+      pt_BR: used for searching
+    llm_description: key words for searching
+    form: llm
diff --git a/api/core/tools/provider/builtin/twilio/_assets/icon.svg b/api/core/tools/provider/builtin/twilio/_assets/icon.svg
new file mode 100644
index 0000000000..a1e2bd12c2
--- /dev/null
+++ b/api/core/tools/provider/builtin/twilio/_assets/icon.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py
new file mode 100644
index 0000000000..984ac3e906
--- /dev/null
+++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py
@@ -0,0 +1,41 @@
+from typing import Any, Union
+
+from langchain.utilities import TwilioAPIWrapper
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class SendMessageTool(BuiltinTool):
+    """
+    A tool for sending messages using Twilio API.
+
+    Args:
+        user_id (str): The ID of the user invoking the tool.
+        tool_parameters (Dict[str, Any]): The parameters required for sending the message.
+
+    Returns:
+        Union[ToolInvokeMessage, List[ToolInvokeMessage]]: The result of invoking the tool, which includes the status of the message sending operation.
+    """
+
+    def _invoke(
+        self, user_id: str, tool_parameters: dict[str, Any]
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        account_sid = self.runtime.credentials["account_sid"]
+        auth_token = self.runtime.credentials["auth_token"]
+        from_number = self.runtime.credentials["from_number"]
+
+        message = tool_parameters["message"]
+        to_number = tool_parameters["to_number"]
+
+        if to_number.startswith("whatsapp:"):
+            from_number = f"whatsapp: {from_number}"
+
+        twilio = TwilioAPIWrapper(
+            account_sid=account_sid, auth_token=auth_token, from_number=from_number
+        )
+
+        # Sending the message through Twilio
+        result = twilio.run(message, to_number)
+
+        return self.create_text_message(text="Message sent successfully.")
diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.yaml b/api/core/tools/provider/builtin/twilio/tools/send_message.yaml
new file mode 100644
index 0000000000..e129698c86
--- /dev/null
+++ b/api/core/tools/provider/builtin/twilio/tools/send_message.yaml
@@ -0,0 +1,40 @@
+identity:
+  name: send_message
+  author: Yash Parmar
+  label:
+    en_US: SendMessage
+    zh_Hans: 发送消息
+    pt_BR: SendMessage
+description:
+  human:
+    en_US: Send SMS or Twilio Messaging Channels messages.
+    zh_Hans: 发送SMS或Twilio消息通道消息。
+    pt_BR: Send SMS or Twilio Messaging Channels messages.
+  llm: Send SMS or Twilio Messaging Channels messages. Supports different channels including WhatsApp.
+parameters:
+  - name: message
+    type: string
+    required: true
+    label:
+      en_US: Message
+      zh_Hans: 消息内容
+      pt_BR: Message
+    human_description:
+      en_US: The content of the message to be sent.
+      zh_Hans: 要发送的消息内容。
+      pt_BR: The content of the message to be sent.
+    llm_description: The content of the message to be sent.
+    form: llm
+  - name: to_number
+    type: string
+    required: true
+    label:
+      en_US: To Number
+      zh_Hans: 收信号码
+      pt_BR: Para Número
+    human_description:
+      en_US: The recipient's phone number. Prefix with 'whatsapp:' for WhatsApp messages, e.g., "whatsapp:+1234567890".
+      zh_Hans: 收件人的电话号码。WhatsApp消息前缀为'whatsapp:',例如,"whatsapp:+1234567890"。
+      pt_BR: The recipient's phone number. Prefix with 'whatsapp:' for WhatsApp messages, e.g., "whatsapp:+1234567890".
+    llm_description: The recipient's phone number. Prefix with 'whatsapp:' for WhatsApp messages, e.g., "whatsapp:+1234567890".
+    form: llm
diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py
new file mode 100644
index 0000000000..dbf30962f9
--- /dev/null
+++ b/api/core/tools/provider/builtin/twilio/twilio.py
@@ -0,0 +1,25 @@
+from typing import Any
+
+from core.tools.errors import ToolProviderCredentialValidationError
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class TwilioProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        try:
+            """
+            SendMessageTool().fork_tool_runtime(
+                meta={
+                    "credentials": credentials,
+                }
+            ).invoke(
+                user_id="",
+                tool_parameters={
+                    "message": "Credential validation message",
+                    "to_number": "+14846624384",
+                },
+            )
+            """
+            pass
+        except Exception as e:
+            raise ToolProviderCredentialValidationError(str(e))
diff --git a/api/core/tools/provider/builtin/twilio/twilio.yaml b/api/core/tools/provider/builtin/twilio/twilio.yaml
new file mode 100644
index 0000000000..b5143c8736
--- /dev/null
+++ b/api/core/tools/provider/builtin/twilio/twilio.yaml
@@ -0,0 +1,46 @@
+identity:
+  author: Yash Parmar
+  name: twilio
+  label:
+    en_US: Twilio
+    zh_Hans: Twilio
+    pt_BR: Twilio
+  description:
+    en_US: Send messages through SMS or Twilio Messaging Channels.
+    zh_Hans: 通过SMS或Twilio消息通道发送消息。
+    pt_BR: Send messages through SMS or Twilio Messaging Channels.
+  icon: icon.svg
+credentials_for_provider:
+  account_sid:
+    type: secret-input
+    required: true
+    label:
+      en_US: Account SID
+      zh_Hans: 账户SID
+      pt_BR: Account SID
+    placeholder:
+      en_US: Please input your Twilio Account SID
+      zh_Hans: 请输入您的Twilio账户SID
+      pt_BR: Please input your Twilio Account SID
+  auth_token:
+    type: secret-input
+    required: true
+    label:
+      en_US: Auth Token
+      zh_Hans: 认证令牌
+      pt_BR: Auth Token
+    placeholder:
+      en_US: Please input your Twilio Auth Token
+      zh_Hans: 请输入您的Twilio认证令牌
+      pt_BR: Please input your Twilio Auth Token
+  from_number:
+    type: secret-input
+    required: true
+    label:
+      en_US: From Number
+      zh_Hans: 发信号码
+      pt_BR: De Número
+    placeholder:
+      en_US: Please input your Twilio phone number
+      zh_Hans: 请输入您的Twilio电话号码
+      pt_BR: Please input your Twilio phone number
diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py
index 9f343d6000..192793897e 100644
--- a/api/core/tools/tool/tool.py
+++ b/api/core/tools/tool/tool.py
@@ -174,7 +174,18 @@ class Tool(BaseModel, ABC):
 
         return result
 
-    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
+    def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
+        # check if tool_parameters is a string
+        if isinstance(tool_parameters, str):
+            # check if this tool has only one parameter
+            parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
+            if parameters and len(parameters) == 1:
+                tool_parameters = {
+                    parameters[0].name: tool_parameters
+                }
+            else:
+                raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
+
+        # update tool_parameters
         if self.runtime.runtime_parameters:
             tool_parameters.update(self.runtime.runtime_parameters)
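The `Tool.invoke` change above lets agent output arrive as a bare string when a tool exposes exactly one LLM-form parameter. A standalone sketch of that coercion rule, using plain dicts as stand-ins for Dify's parameter objects:

```python
from typing import Any, Union

def coerce_parameters(parameters: list[dict[str, str]],
                      tool_parameters: Union[dict[str, Any], str]) -> dict[str, Any]:
    """Mirror of the new branch: a bare string is accepted only when exactly
    one parameter uses the 'llm' form; it becomes that parameter's value."""
    if isinstance(tool_parameters, str):
        llm_params = [p for p in parameters if p["form"] == "llm"]
        if len(llm_params) == 1:
            return {llm_params[0]["name"]: tool_parameters}
        raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
    return tool_parameters

# e.g. the Tavily tool above has a single llm-form parameter named "query":
assert coerce_parameters([{"name": "query", "form": "llm"}], "dify 0.5.8") == {"query": "dify 0.5.8"}
```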
diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py
new file mode 100644
index 0000000000..caa61675fb
--- /dev/null
+++ b/api/extensions/ext_compress.py
@@ -0,0 +1,10 @@
+from flask import Flask
+
+
+def init_app(app: Flask):
+    if app.config.get('API_COMPRESSION_ENABLED', False):
+        from flask_compress import Compress
+
+        compress = Compress()
+        compress.init_app(app)
+
diff --git a/api/requirements.txt b/api/requirements.txt
index ae5c77137a..9721c3a13d 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -3,6 +3,7 @@ beautifulsoup4==4.12.2
 flask~=3.0.1
 Flask-SQLAlchemy~=3.0.5
 SQLAlchemy~=1.4.28
+Flask-Compress~=1.14
 flask-login~=0.6.3
 flask-migrate~=4.0.5
 flask-restful~=0.3.10
@@ -35,7 +36,7 @@ docx2txt==0.8
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.7.7
+anthropic~=0.17.0
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
@@ -67,4 +68,7 @@ pydub~=0.25.1
 gmpy2~=2.1.5
 numexpr~=2.9.0
 duckduckgo-search==4.4.3
-arxiv==2.1.0
\ No newline at end of file
+arxiv==2.1.0
+yarl~=1.9.4
+twilio==9.0.0
+qrcode~=7.4.2
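The mock rewrite below tracks the `anthropic~=0.17.0` bump above: the SDK's legacy `Completions.create` is replaced by the `Messages` API, which streams typed events instead of `Completion` chunks. For reference, the shape of a call against the new API — a minimal sketch assuming only the public 0.17 SDK, with a placeholder key:

```python
from anthropic import Anthropic

client = Anthropic(api_key="sk-ant-placeholder")
message = client.messages.create(
    model="claude-instant-1.2",
    max_tokens=100,  # renamed from the old max_tokens_to_sample
    messages=[{"role": "user", "content": "Hello, who are you?"}],
)
# content is a list of blocks; usage reports token counts, as the mock mirrors
print(message.content[0].text, message.usage.output_tokens)
```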
diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py
index 96fd8f2026..2247d33e24 100644
--- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py
+++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py
@@ -1,52 +1,87 @@
 import os
 from time import sleep
-from typing import Any, Generator, List, Literal, Union
+from typing import Any, Literal, Union, Iterable
+
+from anthropic.resources import Messages
+from anthropic.types.message_delta_event import Delta
 
 import anthropic
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from anthropic import Anthropic
-from anthropic._types import NOT_GIVEN, Body, Headers, NotGiven, Query
-from anthropic.resources.completions import Completions
-from anthropic.types import Completion, completion_create_params
+from anthropic import Anthropic, Stream
+from anthropic.types import MessageParam, Message, MessageStreamEvent, \
+    ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \
+    MessageDeltaUsage
 
 MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
 
+
 class MockAnthropicClass(object):
     @staticmethod
-    def mocked_anthropic_chat_create_sync(model: str) -> Completion:
-        return Completion(
-            completion='hello, I\'m a chatbot from anthropic',
+    def mocked_anthropic_chat_create_sync(model: str) -> Message:
+        return Message(
+            id='msg-123',
+            type='message',
+            role='assistant',
+            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
             model=model,
-            stop_reason='stop_sequence'
+            stop_reason='stop_sequence',
+            usage=Usage(
+                input_tokens=1,
+                output_tokens=1
+            )
         )
 
     @staticmethod
-    def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
+    def mocked_anthropic_chat_create_stream(model: str) -> Stream[MessageStreamEvent]:
         full_response_text = "hello, I'm a chatbot from anthropic"
 
-        for i in range(0, len(full_response_text) + 1):
-            sleep(0.1)
-            if i == len(full_response_text):
-                yield Completion(
-                    completion='',
-                    model=model,
-                    stop_reason='stop_sequence'
-                )
-            else:
-                yield Completion(
-                    completion=full_response_text[i],
-                    model=model,
-                    stop_reason=''
+        yield MessageStartEvent(
+            type='message_start',
+            message=Message(
+                id='msg-123',
+                content=[],
+                role='assistant',
+                model=model,
+                stop_reason=None,
+                type='message',
+                usage=Usage(
+                    input_tokens=1,
+                    output_tokens=1
                 )
+            )
+        )
 
-    def mocked_anthropic(self: Completions, *,
-                         max_tokens_to_sample: int,
-                         model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
-                         prompt: str,
-                         stream: Literal[True],
-                         **kwargs: Any
-                         ) -> Union[Completion, Generator[Completion, None, None]]:
+        index = 0
+        for i in range(0, len(full_response_text)):
+            sleep(0.1)
+            yield ContentBlockDeltaEvent(
+                type='content_block_delta',
+                delta=TextDelta(text=full_response_text[i], type='text_delta'),
+                index=index
+            )
+
+            index += 1
+
+        yield MessageDeltaEvent(
+            type='message_delta',
+            delta=Delta(
+                stop_reason='stop_sequence'
+            ),
+            usage=MessageDeltaUsage(
+                output_tokens=1
+            )
+        )
+
+        yield MessageStopEvent(type='message_stop')
+
+    def mocked_anthropic(self: Messages, *,
+                         max_tokens: int,
+                         messages: Iterable[MessageParam],
+                         model: str,
+                         stream: Literal[True],
+                         **kwargs: Any
+                         ) -> Union[Message, Stream[MessageStreamEvent]]:
         if len(self._client.api_key) < 18:
             raise anthropic.AuthenticationError('Invalid API key')
 
@@ -55,12 +90,13 @@ class MockAnthropicClass(object):
         else:
             return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)
 
+
 @pytest.fixture
 def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic)
+        monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
 
     yield
 
     if MOCK:
-        monkeypatch.undo()
\ No newline at end of file
+        monkeypatch.undo()
diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py
index e4cc2ceea6..bba5704d2e 100644
--- a/api/tests/integration_tests/model_runtime/__mock/xinference.py
+++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py
@@ -32,68 +32,70 @@ class MockXinferenceClass(object):
         response = Response()
         if 'v1/models/' in url:
             # get model uid
-            model_uid = url.split('/')[-1]
+            model_uid = url.split('/')[-1] or ''
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
+                response._content = b'{}'
                 return response
 
             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
+                response._content = b'{}'
                 return response
 
             if model_uid in ['generate', 'chat']:
                 response.status_code = 200
                 response._content = b'''{
-    "model_type": "LLM",
-    "address": "127.0.0.1:43877",
-    "accelerators": [
-        "0",
-        "1"
-    ],
-    "model_name": "chatglm3-6b",
-    "model_lang": [
-        "en"
-    ],
-    "model_ability": [
-        "generate",
-        "chat"
-    ],
-    "model_description": "latest chatglm3",
-    "model_format": "pytorch",
-    "model_size_in_billions": 7,
-    "quantization": "none",
-    "model_hub": "huggingface",
-    "revision": null,
-    "context_length": 2048,
-    "replica": 1
-    }'''
+                    "model_type": "LLM",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "chatglm3-6b",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "model_ability": [
+                        "generate",
+                        "chat"
+                    ],
+                    "model_description": "latest chatglm3",
+                    "model_format": "pytorch",
+                    "model_size_in_billions": 7,
+                    "quantization": "none",
+                    "model_hub": "huggingface",
+                    "revision": null,
+                    "context_length": 2048,
+                    "replica": 1
+                }'''
                 return response
             elif model_uid == 'embedding':
                 response.status_code = 200
                 response._content = b'''{
-    "model_type": "embedding",
-    "address": "127.0.0.1:43877",
-    "accelerators": [
-        "0",
-        "1"
-    ],
-    "model_name": "bge",
-    "model_lang": [
-        "en"
-    ],
-    "revision": null,
-    "max_tokens": 512
-}'''
+                    "model_type": "embedding",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "bge",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "revision": null,
+                    "max_tokens": 512
+                }'''
                 return response
         elif 'v1/cluster/auth' in url:
             response.status_code = 200
             response._content = b'''{
-    "auth": true
-}'''
+                "auth": true
+            }'''
             return response
 
     def _check_cluster_authenticated(self):
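One detail worth calling out in the xinference mock above: the 404 paths now set `response._content = b'{}'`, because a bare `requests` `Response` has no body at all, so any client code that parses the error body would crash before the behavior under test is reached. A minimal illustration using only the standard `requests` API:

```python
from requests.models import Response

resp = Response()
resp.status_code = 404
resp._content = b'{}'      # what the updated mock sets on error paths
assert resp.json() == {}   # callers parsing the error body no longer blow up
```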
diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py
index ddba2a40ce..b3f6414800 100644
--- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py
@@ -15,14 +15,14 @@ def test_validate_credentials(setup_anthropic_mock):
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='claude-instant-1',
+            model='claude-instant-1.2',
            credentials={
                'anthropic_api_key': 'invalid_key'
            }
        )
 
     model.validate_credentials(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
        credentials={
            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
        }
@@ -33,7 +33,7 @@ def test_invoke_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     response = model.invoke(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
             'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
@@ -49,7 +49,7 @@ def test_invoke_model(setup_anthropic_mock):
         model_parameters={
             'temperature': 0.0,
             'top_p': 1.0,
-            'max_tokens_to_sample': 10
+            'max_tokens': 10
         },
         stop=['How'],
         stream=False,
@@ -64,7 +64,7 @@ def test_invoke_stream_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     response = model.invoke(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
        credentials={
            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
        },
@@ -78,7 +78,7 @@
         ],
         model_parameters={
             'temperature': 0.0,
-            'max_tokens_to_sample': 100
+            'max_tokens': 100
         },
         stream=True,
         user="abc-123"
@@ -97,7 +97,7 @@ def test_get_num_tokens():
     model = AnthropicLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='claude-instant-1',
+        model='claude-instant-1.2',
         credentials={
             'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
         },
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 7cd09fd6ea..dfa01b6cef 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -2,7 +2,7 @@ version: '3.1'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.5.7
+    image: langgenius/dify-api:0.5.8
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -135,7 +135,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.5.7
+    image: langgenius/dify-api:0.5.8
     restart: always
     environment:
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -206,7 +206,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.5.7
+    image: langgenius/dify-web:0.5.8
     restart: always
     environment:
       EDITION: SELF_HOSTED
diff --git a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx
index 6bd40547ca..02cac061b4 100644
--- a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx
+++ b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx
@@ -40,6 +40,7 @@ const TextToSpeech: FC = () => {
         { languageInfo?.example && (
         )}
diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx
index c10755d3a1..1492b859fa 100644
--- a/web/app/components/base/audio-btn/index.tsx
+++ b/web/app/components/base/audio-btn/index.tsx
@@ -9,12 +9,14 @@ import { textToAudio } from '@/service/share'
 
 type AudioBtnProps = {
   value: string
+  voice?: string
   className?: string
   isAudition?: boolean
 }
 
 const AudioBtn = ({
   value,
+  voice,
   className,
   isAudition,
 }: AudioBtnProps) => {
@@ -27,13 +29,16 @@ const AudioBtn = ({
   const pathname = usePathname()
   const removeCodeBlocks = (inputText: any) => {
     const codeBlockRegex = /```[\s\S]*?```/g
-    return inputText.replace(codeBlockRegex, '')
+    if (inputText)
+      return inputText.replace(codeBlockRegex, '')
+    return ''
   }
 
   const playAudio = async () => {
     const formData = new FormData()
     if (value !== '') {
       formData.append('text', removeCodeBlocks(value))
+      formData.append('voice', removeCodeBlocks(voice))
 
       let url = ''
       let isPublic = false
@@ -56,13 +61,14 @@ const AudioBtn = ({
           const audioUrl = URL.createObjectURL(blob)
           const audio = new Audio(audioUrl)
           audioRef.current = audio
-          audio.play().then(() => {
-            setIsPlaying(true)
-          }).catch(() => {
+          audio.play().then(() => {}).catch(() => {
             setIsPlaying(false)
             URL.revokeObjectURL(audioUrl)
           })
-          audio.onended = () => setHasEnded(true)
+          audio.onended = () => {
+            setHasEnded(true)
+            setIsPlaying(false)
+          }
         }
         catch (error) {
           setIsPlaying(false)
@@ -70,24 +76,34 @@ const AudioBtn = ({
       }
     }
   }
-
   const togglePlayPause = () => {
     if (audioRef.current) {
       if (isPlaying) {
-        setPause(true)
-        audioRef.current.pause()
-      }
-      else if (!hasEnded) {
-        setPause(false)
-        audioRef.current.play()
+        if (!hasEnded) {
+          setPause(false)
+          audioRef.current.play()
+        }
+        if (!isPause) {
+          setPause(true)
+          audioRef.current.pause()
+        }
       }
       else if (!isPlaying) {
-        playAudio().then()
+        if (isPause) {
+          setPause(false)
+          audioRef.current.play()
+        }
+        else {
+          setHasEnded(false)
+          playAudio().then()
+        }
       }
       setIsPlaying(prevIsPlaying => !prevIsPlaying)
     }
     else {
-      playAudio().then()
+      setIsPlaying(true)
+      if (!isPlaying)
+        playAudio().then()
     }
   }
 
@@ -102,7 +118,7 @@ const AudioBtn = ({
         className={`box-border p-0.5 flex items-center justify-center cursor-pointer ${isAudition || 'rounded-md bg-white'}`}
         style={{ boxShadow: !isAudition ? '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)' : '' }}
         onClick={togglePlayPause}>
-
+
diff --git a/web/app/components/base/audio-btn/style.module.css b/web/app/components/base/audio-btn/style.module.css
index 7c05003b04..7e3175aa13 100644
--- a/web/app/components/base/audio-btn/style.module.css
+++ b/web/app/components/base/audio-btn/style.module.css
@@ -8,9 +8,3 @@
   background-position: center;
   background-repeat: no-repeat;
 }
-
-.stopIcon {
-  background-position: center;
-  background-repeat: no-repeat;
-  background-image: url(~@/app/components/develop/secret-key/assets/stop.svg);
-}
\ No newline at end of file
diff --git a/web/app/components/base/chat/chat/answer/operation.tsx b/web/app/components/base/chat/chat/answer/operation.tsx
index eb5dead657..8a791d82da 100644
--- a/web/app/components/base/chat/chat/answer/operation.tsx
+++ b/web/app/components/base/chat/chat/answer/operation.tsx
@@ -77,6 +77,7 @@ const Operation: FC = ({
         {(!isOpeningStatement && config?.text_to_speech?.enabled) && (
         )}
diff --git a/web/app/components/base/chat/chat/index.tsx b/web/app/components/base/chat/chat/index.tsx
index 99032e61ed..67a91e2150 100644
--- a/web/app/components/base/chat/chat/index.tsx
+++ b/web/app/components/base/chat/chat/index.tsx
@@ -9,6 +9,7 @@ import {
 } from 'react'
 import { useTranslation } from 'react-i18next'
 import { useThrottleEffect } from 'ahooks'
+import { debounce } from 'lodash-es'
 import type {
   ChatConfig,
   ChatItem,
@@ -81,16 +82,24 @@ const Chat: FC = ({
       chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight
   }
 
-  useThrottleEffect(() => {
-    handleScrolltoBottom()
-
+  const handleWindowResize = () => {
     if (chatContainerRef.current && chatFooterRef.current)
       chatFooterRef.current.style.width = `${chatContainerRef.current.clientWidth}px`
 
     if (chatContainerInnerRef.current && chatFooterInnerRef.current)
       chatFooterInnerRef.current.style.width = `${chatContainerInnerRef.current.clientWidth}px`
+  }
+
+  useThrottleEffect(() => {
+    handleScrolltoBottom()
+    handleWindowResize()
   }, [chatList], { wait: 500 })
 
+  useEffect(() => {
+    window.addEventListener('resize', debounce(handleWindowResize))
+    return () => window.removeEventListener('resize', handleWindowResize)
+  }, [])
+
   useEffect(() => {
     if (chatFooterRef.current && chatContainerRef.current) {
       const resizeObserver = new ResizeObserver((entries) => {
diff --git a/web/app/components/develop/secret-key/assets/play.svg b/web/app/components/develop/secret-key/assets/play.svg
index 0ab33af6c6..b423e98ce2 100644
--- a/web/app/components/develop/secret-key/assets/play.svg
+++ b/web/app/components/develop/secret-key/assets/play.svg
@@ -1,7 +1,7 @@
-
-
+
+
diff --git a/web/app/components/develop/secret-key/assets/stop.svg b/web/app/components/develop/secret-key/assets/stop.svg
deleted file mode 100644
index b423e98ce2..0000000000
--- a/web/app/components/develop/secret-key/assets/stop.svg
+++ /dev/null
@@ -1,11 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx
index e102108154..ca400d1438 100644
--- a/web/app/components/develop/template/template_chat.en.mdx
+++ b/web/app/components/develop/template/template_chat.en.mdx
@@ -71,8 +71,8 @@ Chat applications support session persistence, allowing previous chat history to
         - `upload_file_id` (string) Uploaded file ID, which must be obtained by uploading through the File Upload API in advance (when the transfer method is `local_file`)
-    Auto-generate title, default is `false`.
-    Can achieve async title generation by calling the conversation rename API and setting `auto_generate` to true.
+    Auto-generate title, default is `true`.
+    If set to `false`, can achieve async title generation by calling the conversation rename API and setting `auto_generate` to `true`.
diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx
index 7bc3cd5337..dcc33ecf0b 100644
--- a/web/app/components/develop/template/template_chat.zh.mdx
+++ b/web/app/components/develop/template/template_chat.zh.mdx
@@ -71,7 +71,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
       - `upload_file_id` 上传文件 ID。(仅当传递方式为 `local_file `时)。
-    (选填)自动生成标题,默认 `false`。 可通过调用会话重命名接口并设置 `auto_generate` 为 `true` 实现异步生成标题。
+    (选填)自动生成标题,默认 `true`。 若设置为 `false`,则可通过调用会话重命名接口并设置 `auto_generate` 为 `true` 实现异步生成标题。
diff --git a/web/package.json b/web/package.json
index 66bb78383c..84182ecd3b 100644
--- a/web/package.json
+++ b/web/package.json
@@ -1,6 +1,6 @@
 {
   "name": "dify-web",
-  "version": "0.5.7",
+  "version": "0.5.8",
   "private": true,
   "scripts": {
     "dev": "next dev",
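The `template_chat` doc change above flips the documented default for auto-generated conversation titles to `true` and points at the conversation rename API for the async path. A sketch of that call — the endpoint shape follows Dify's published service API as I understand it, and the base URL, app key, conversation ID, and user are placeholders:

```python
import requests

resp = requests.post(
    "https://api.dify.ai/v1/conversations/your-conversation-id/name",
    headers={"Authorization": "Bearer app-your-api-key"},
    # auto_generate=True asks the server to derive the title from the conversation
    json={"auto_generate": True, "user": "abc-123"},
    timeout=10,
)
print(resp.json())  # the conversation, now carrying its generated name
```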