diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 8824c5dba6..cf4fce3700 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -22,7 +22,7 @@ body: - type: input attributes: label: Dify version - placeholder: 0.6.11 + placeholder: 0.6.15 description: See about section in Dify console validations: required: true diff --git a/api/.env.example b/api/.env.example index 474798cef7..80ef185e51 100644 --- a/api/.env.example +++ b/api/.env.example @@ -216,6 +216,7 @@ UNSTRUCTURED_API_KEY= SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= +SSRF_DEFAULT_MAX_RETRIES=3 BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index a32b70bdc7..07688e9aeb 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,4 +1,5 @@ from typing import Any, Optional +from urllib.parse import quote_plus from pydantic import Field, NonNegativeInt, PositiveInt, computed_field from pydantic_settings import BaseSettings @@ -104,7 +105,7 @@ class DatabaseConfig: ).strip("&") db_extras = f"?{db_extras}" if db_extras else "" return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" - f"{self.DB_USERNAME}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" + f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" f"{db_extras}") SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field( diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index e6bf6cc3a3..afd383880f 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -34,12 +34,6 @@ class Segment(BaseModel): return str(self.value) def to_object(self) -> Any: - if isinstance(self.value, Segment): - return self.value.to_object() - if isinstance(self.value, list): - return [v.to_object() for v in self.value] - if isinstance(self.value, dict): - return {k: v.to_object() for k, v in self.value.items()} return self.value diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index b020914d84..5edaccc4d6 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -56,6 +56,9 @@ class ObjectVariable(Variable): # TODO: Use markdown code block return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + def to_object(self): + return {k: v.to_object() for k, v in self.value.items()} + class ArrayVariable(Variable): value_type: SegmentType = SegmentType.ARRAY @@ -65,6 +68,9 @@ class ArrayVariable(Variable): def markdown(self) -> str: return '\n'.join(['- ' + item.markdown for item in self.value]) + def to_object(self): + return [v.to_object() for v in self.value] + class FileVariable(Variable): value_type: SegmentType = SegmentType.FILE diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 019b27f28a..14ca8e943c 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,48 +1,75 @@ """ Proxy requests to avoid SSRF """ +import logging import os +import time import httpx SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') +SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) proxies = { 'http://': SSRF_PROXY_HTTP_URL, 'https://': SSRF_PROXY_HTTPS_URL } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None +BACKOFF_FACTOR = 0.5 +STATUS_FORCELIST = [429, 500, 502, 503, 504] -def make_request(method, url, **kwargs): - if SSRF_PROXY_ALL_URL: - return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) - elif proxies: - return httpx.request(method=method, url=url, proxies=proxies, **kwargs) - else: - return httpx.request(method=method, url=url, **kwargs) +def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + if "allow_redirects" in kwargs: + allow_redirects = kwargs.pop("allow_redirects") + if "follow_redirects" not in kwargs: + kwargs["follow_redirects"] = allow_redirects + + retries = 0 + while retries <= max_retries: + try: + if SSRF_PROXY_ALL_URL: + response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) + elif proxies: + response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) + else: + response = httpx.request(method=method, url=url, **kwargs) + + if response.status_code not in STATUS_FORCELIST: + return response + else: + logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") + + except httpx.RequestError as e: + logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") + + retries += 1 + if retries <= max_retries: + time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) + + raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") -def get(url, **kwargs): - return make_request('GET', url, **kwargs) +def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('GET', url, max_retries=max_retries, **kwargs) -def post(url, **kwargs): - return make_request('POST', url, **kwargs) +def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('POST', url, max_retries=max_retries, **kwargs) -def put(url, **kwargs): - return make_request('PUT', url, **kwargs) +def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('PUT', url, max_retries=max_retries, **kwargs) -def patch(url, **kwargs): - return make_request('PATCH', url, **kwargs) +def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('PATCH', url, max_retries=max_retries, **kwargs) -def delete(url, **kwargs): - return make_request('DELETE', url, **kwargs) +def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('DELETE', url, max_retries=max_retries, **kwargs) -def head(url, **kwargs): - return make_request('HEAD', url, **kwargs) +def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('HEAD', url, max_retries=max_retries, **kwargs) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml index 3a79a929ba..c523596b57 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml @@ -10,10 +10,13 @@ - cohere.command-text-v14 - cohere.command-r-plus-v1.0 - cohere.command-r-v1.0 +- meta.llama3-1-8b-instruct-v1:0 +- meta.llama3-1-70b-instruct-v1:0 - meta.llama3-8b-instruct-v1:0 - meta.llama3-70b-instruct-v1:0 - meta.llama2-13b-chat-v1 - meta.llama2-70b-chat-v1 +- mistral.mistral-large-2407-v1:0 - mistral.mistral-small-2402-v1:0 - mistral.mistral-large-2402-v1:0 - mistral.mixtral-8x7b-instruct-v0:1 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 882d0b6352..e9906c8294 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -208,14 +208,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if model_info['support_tool_use'] and tools: parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools) + try: + if stream: + response = bedrock_client.converse_stream(**parameters) + return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) + else: + response = bedrock_client.converse(**parameters) + return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + except ClientError as ex: + error_code = ex.response['Error']['Code'] + full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" + raise self._map_client_to_invoke_error(error_code, full_error_msg) + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: + raise InvokeConnectionError(str(ex)) - if stream: - response = bedrock_client.converse_stream(**parameters) - return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) - else: - response = bedrock_client.converse(**parameters) - return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + except UnknownServiceError as ex: + raise InvokeServerUnavailableError(str(ex)) + except Exception as ex: + raise InvokeError(str(ex)) def _handle_converse_response(self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]) -> LLMResult: """ @@ -558,7 +569,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel): except ClientError as ex: error_code = ex.response['Error']['Code'] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" - raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) except Exception as ex: diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-70b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-70b-instruct-v1.0.yaml new file mode 100644 index 0000000000..10bfa7b1d5 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-70b-instruct-v1.0.yaml @@ -0,0 +1,25 @@ +model: meta.llama3-1-70b-instruct-v1:0 +label: + en_US: Llama 3.1 Instruct 70B +model_type: llm +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + - name: top_p + use_template: top_p + default: 0.9 + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.00265' + output: '0.0035' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-8b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-8b-instruct-v1.0.yaml new file mode 100644 index 0000000000..81cd53243f --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/meta.llama3-1-8b-instruct-v1.0.yaml @@ -0,0 +1,25 @@ +model: meta.llama3-1-8b-instruct-v1:0 +label: + en_US: Llama 3.1 Instruct 8B +model_type: llm +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + - name: top_p + use_template: top_p + default: 0.9 + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.0003' + output: '0.0006' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2407-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2407-v1.0.yaml new file mode 100644 index 0000000000..19d7843a57 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/mistral.mistral-large-2407-v1.0.yaml @@ -0,0 +1,29 @@ +model: mistral.mistral-large-2407-v1:0 +label: + en_US: Mistral Large 2 (24.07) +model_type: llm +features: + - tool-call +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + required: false + default: 0.7 + - name: top_p + use_template: top_p + required: false + default: 1 + - name: max_tokens + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.003' + output: '0.009' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 2b6d8e0047..8859dd72bd 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -14,6 +14,7 @@ from core.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, SystemPromptMessage, + ToolPromptMessage, UserPromptMessage, ) from core.model_runtime.errors.invoke import InvokeError @@ -44,6 +45,17 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): "Stream": stream, **custom_parameters, } + # add Tools and ToolChoice + if (tools and len(tools) > 0): + params['ToolChoice'] = "auto" + params['Tools'] = [{ + "Type": "function", + "Function": { + "Name": tool.name, + "Description": tool.description, + "Parameters": json.dumps(tool.parameters) + } + } for tool in tools] request.from_json_string(json.dumps(params)) response = client.ChatCompletions(request) @@ -89,9 +101,43 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage]) -> list[dict]: """Convert a list of PromptMessage objects to a list of dictionaries with 'Role' and 'Content' keys.""" - return [{"Role": message.role.value, "Content": message.content} for message in prompt_messages] + dict_list = [] + for message in prompt_messages: + if isinstance(message, AssistantPromptMessage): + tool_calls = message.tool_calls + if (tool_calls and len(tool_calls) > 0): + dict_tool_calls = [ + { + "Id": tool_call.id, + "Type": tool_call.type, + "Function": { + "Name": tool_call.function.name, + "Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}" + } + } for tool_call in tool_calls] + + dict_list.append({ + "Role": message.role.value, + # fix set content = "" while tool_call request + # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. + "Content": " ", # message.content if (message.content is not None) else "", + "ToolCalls": dict_tool_calls + }) + else: + dict_list.append({ "Role": message.role.value, "Content": message.content }) + elif isinstance(message, ToolPromptMessage): + tool_execute_result = { "result": message.content } + content =json.dumps(tool_execute_result, ensure_ascii=False) + dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id }) + else: + dict_list.append({ "Role": message.role.value, "Content": message.content }) + return dict_list def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp): + + tool_call = None + tool_calls = [] + for index, event in enumerate(resp): logging.debug("_handle_stream_chat_response, event: %s", event) @@ -109,20 +155,54 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): usage = data.get('Usage', {}) prompt_tokens = usage.get('PromptTokens', 0) completion_tokens = usage.get('CompletionTokens', 0) - usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + response_tool_calls = delta.get('ToolCalls') + if (response_tool_calls is not None): + new_tool_calls = self._extract_response_tool_calls(response_tool_calls) + if (len(new_tool_calls) > 0): + new_tool_call = new_tool_calls[0] + if (tool_call is None): tool_call = new_tool_call + elif (tool_call.id != new_tool_call.id): + tool_calls.append(tool_call) + tool_call = new_tool_call + else: + tool_call.function.name += new_tool_call.function.name + tool_call.function.arguments += new_tool_call.function.arguments + if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0): + tool_calls.append(tool_call) + tool_call = None assistant_prompt_message = AssistantPromptMessage( content=message_content, tool_calls=[] ) + # rewrite content = "" while tool_call to avoid show content on web page + if (len(tool_calls) > 0): assistant_prompt_message.content = "" + + # add tool_calls to assistant_prompt_message + if (finish_reason == 'tool_calls'): + assistant_prompt_message.tool_calls = tool_calls + tool_call = None + tool_calls = [] - delta_chunk = LLMResultChunkDelta( - index=index, - role=delta.get('Role', 'assistant'), - message=assistant_prompt_message, - usage=usage, - finish_reason=finish_reason, - ) + if (len(finish_reason) > 0): + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + delta_chunk = LLMResultChunkDelta( + index=index, + role=delta.get('Role', 'assistant'), + message=assistant_prompt_message, + usage=usage, + finish_reason=finish_reason, + ) + tool_call = None + tool_calls = [] + + else: + delta_chunk = LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + ) yield LLMResultChunk( model=model, @@ -177,12 +257,15 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): """ human_prompt = "\n\nHuman:" ai_prompt = "\n\nAssistant:" + tool_prompt = "\n\nTool:" content = message.content if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" + elif isinstance(message, ToolPromptMessage): + message_text = f"{tool_prompt} {content}" elif isinstance(message, SystemPromptMessage): message_text = content else: @@ -203,3 +286,30 @@ class HunyuanLargeLanguageModel(LargeLanguageModel): return { InvokeError: [TencentCloudSDKException], } + + def _extract_response_tool_calls(self, + response_tool_calls: list[dict]) \ + -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + response_function = response_tool_call.get('Function', {}) + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function.get('Name', ''), + arguments=response_function.get('Arguments', '') + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.get('Id', 0), + type='function', + function=function + ) + tool_calls.append(tool_call) + + return tool_calls \ No newline at end of file diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index f65f57da60..4bd09b331d 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -55,7 +55,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) """ SQL_CREATE_INDEX = """ -CREATE INDEX idx_docs_{table_name} ON {table_name}(text) +CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text) INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS ('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER sys.my_chinese_vgram_lexer') """ @@ -248,7 +248,7 @@ class OracleVector(BaseVector): def delete(self) -> None: with self._get_cursor() as cur: - cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints") def _create_collection(self, dimension: int): cache_key = f"vector_indexing_{self._collection_name}" diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index d01cf48fac..f7a08135f5 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -4,9 +4,8 @@ from pathlib import Path from typing import Union from urllib.parse import unquote -import requests - from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.csv_extractor import CSVExtractor from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -51,7 +50,7 @@ class ExtractProcessor: @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = requests.get(url, headers={ + response = ssrf_proxy.get(url, headers={ "User-Agent": USER_AGENT }) diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index faa1e64057..b24cf2e170 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -54,8 +54,16 @@ class MarkdownExtractor(BaseExtractor): current_header = None current_text = "" + code_block_flag = False for line in lines: + if line.startswith("```"): + code_block_flag = not code_block_flag + current_text += line + "\n" + continue + if code_block_flag: + current_text += line + "\n" + continue header_match = re.match(r"^#+\s", line) if header_match: if current_header is not None: diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index 8409129833..cee46cee23 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -60,11 +60,13 @@ class JinaReaderTool(BuiltinTool): if tool_parameters.get('no_cache', False): headers['X-No-Cache'] = 'true' + max_retries = tool_parameters.get('max_retries', 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, params=request_params, timeout=(10, 60), + max_retries=max_retries ) if tool_parameters.get('summary', False): diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml index 072e7f0528..58ad6d8694 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml @@ -150,3 +150,17 @@ parameters: pt_BR: Habilitar resumo para a saída llm_description: enable summary form: form + - name: max_retries + type: number + required: false + default: 3 + label: + en_US: Retry + zh_Hans: 重试 + pt_BR: Repetir + human_description: + en_US: Number of times to retry the request if it fails + zh_Hans: 请求失败时重试的次数 + pt_BR: Número de vezes para repetir a solicitação se falhar + llm_description: Number of times to retry the request if it fails + form: form diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index e6bc08147f..d4a81cd096 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -40,10 +40,12 @@ class JinaSearchTool(BuiltinTool): if tool_parameters.get('no_cache', False): headers['X-No-Cache'] = 'true' + max_retries = tool_parameters.get('max_retries', 3) response = ssrf_proxy.get( str(URL(self._jina_search_endpoint + query)), headers=headers, - timeout=(10, 60) + timeout=(10, 60), + max_retries=max_retries ) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml index da0a300c6c..2bc70e1be1 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml @@ -91,3 +91,17 @@ parameters: pt_BR: Ignorar o cache llm_description: bypass the cache form: form + - name: max_retries + type: number + required: false + default: 3 + label: + en_US: Retry + zh_Hans: 重试 + pt_BR: Repetir + human_description: + en_US: Number of times to retry the request if it fails + zh_Hans: 请求失败时重试的次数 + pt_BR: Número de vezes para repetir a solicitação se falhar + llm_description: Number of times to retry the request if it fails + form: form diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index e52082541a..f6f04271d6 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -11,11 +11,10 @@ from contextlib import contextmanager from urllib.parse import unquote import cloudscraper -import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString -from newspaper import Article from regex import regex +from core.helper import ssrf_proxy from core.rag.extractor import extract_processor from core.rag.extractor.extract_processor import ExtractProcessor @@ -45,7 +44,7 @@ def get_url(url: str, user_agent: str = None) -> str: main_content_type = None supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] - response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) + response = ssrf_proxy.head(url, headers=headers, follow_redirects=True, timeout=(5, 10)) if response.status_code == 200: # check content-type @@ -67,10 +66,11 @@ def get_url(url: str, user_agent: str = None) -> str: if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: return ExtractProcessor.load_from_url(url, return_text=True) - response = requests.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) + response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - response = scraper.get(url, headers=headers, allow_redirects=True, timeout=(120, 300)) + scraper.perform_request = ssrf_proxy.make_request + response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) if response.status_code != 200: return "URL returned status code {}.".format(response.status_code) @@ -78,7 +78,7 @@ def get_url(url: str, user_agent: str = None) -> str: a = extract_using_readabilipy(response.text) if not a['plain_text'] or not a['plain_text'].strip(): - return get_url_from_newspaper3k(url) + return '' res = FULL_TEMPLATE.format( title=a['title'], @@ -91,23 +91,6 @@ def get_url(url: str, user_agent: str = None) -> str: return res -def get_url_from_newspaper3k(url: str) -> str: - - a = Article(url) - a.download() - a.parse() - - res = FULL_TEMPLATE.format( - title=a.title, - authors=a.authors, - publish_date=a.publish_date, - top_image=a.top_image, - text=a.text, - ) - - return res - - def extract_using_readabilipy(html): with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: f_html.write(html) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 1bd126f842..238477117d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -125,11 +125,15 @@ class ToolNode(BaseNode): ] else: tool_input = node_data.tool_parameters[parameter_name] - segment_group = parser.convert_template( - template=str(tool_input.value), - variable_pool=variable_pool, - ) - result[parameter_name] = segment_group.log if for_log else segment_group.text + if tool_input.type == 'variable': + parameter_value = variable_pool.get(tool_input.value).value + else: + segment_group = parser.convert_template( + template=str(tool_input.value), + variable_pool=variable_pool, + ) + parameter_value = segment_group.log if for_log else segment_group.text + result[parameter_name] = parameter_value return result diff --git a/api/extensions/storage/tencent_storage.py b/api/extensions/storage/tencent_storage.py index 6d9fb80f5e..e2c1ca55e3 100644 --- a/api/extensions/storage/tencent_storage.py +++ b/api/extensions/storage/tencent_storage.py @@ -32,8 +32,7 @@ class TencentStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - while chunk := response['Body'].get_stream(chunk_size=4096): - yield chunk + yield from response['Body'].get_stream(chunk_size=4096) return generate() diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index c98c332021..ff33a97ff2 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,10 +1,12 @@ from flask_restful import fields -from core.app.segments import SecretVariable, Variable +from core.app.segments import SecretVariable, SegmentType, Variable from core.helper import encrypter from fields.member_fields import simple_account_fields from libs.helper import TimestampField +ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) + class EnvironmentVariableField(fields.Raw): def format(self, value): @@ -16,14 +18,18 @@ class EnvironmentVariableField(fields.Raw): 'value': encrypter.obfuscated_token(value.value), 'value_type': value.value_type.value, } - elif isinstance(value, Variable): + if isinstance(value, Variable): return { 'id': value.id, 'name': value.name, 'value': value.value, 'value_type': value.value_type.value, } - return value + if isinstance(value, dict): + value_type = value.get('value_type') + if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: + raise ValueError(f'Unsupported environment variable value type: {value_type}') + return value environment_variable_fields = { diff --git a/api/tests/unit_tests/core/helper/__init__.py b/api/tests/unit_tests/core/helper/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py new file mode 100644 index 0000000000..d917bb1003 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -0,0 +1,52 @@ +import random +from unittest.mock import MagicMock, patch + +from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request + + +@patch('httpx.request') +def test_successful_request(mock_request): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + response = make_request('GET', 'http://example.com') + assert response.status_code == 200 + + +@patch('httpx.request') +def test_retry_exceed_max_retries(mock_request): + mock_response = MagicMock() + mock_response.status_code = 500 + + side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES + mock_request.side_effect = side_effects + + try: + make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + raise AssertionError("Expected Exception not raised") + except Exception as e: + assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" + + +@patch('httpx.request') +def test_retry_logic_success(mock_request): + side_effects = [] + + for _ in range(SSRF_DEFAULT_MAX_RETRIES): + status_code = random.choice(STATUS_FORCELIST) + mock_response = MagicMock() + mock_response.status_code = status_code + side_effects.append(mock_response) + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + side_effects.append(mock_response_200) + + mock_request.side_effect = side_effects + + response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) + + assert response.status_code == 200 + assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 + assert mock_request.call_args_list[0][1].get('method') == 'GET' diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index a4f8b6839f..9ef624ed97 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -96,7 +96,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar ...rest, type: type === InputVarType.textInput ? 'string' : type, key: variable, - name: label, + name: label as string, } if (payload.type === InputVarType.textInput) diff --git a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx b/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx index 759a15213d..6fb58ba9a1 100644 --- a/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx +++ b/web/app/components/app/configuration/prompt-mode/advanced-mode-waring.tsx @@ -4,7 +4,6 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import I18n from '@/context/i18n' -import { FlipBackward } from '@/app/components/base/icons/src/vender/line/arrows' import { LanguagesSupported } from '@/i18n/language' type Props = { onReturnToSimpleMode: () => void @@ -38,7 +37,6 @@ const AdvancedModeWarning: FC = ({ onClick={onReturnToSimpleMode} className='shrink-0 flex items-center h-6 px-2 bg-indigo-600 shadow-xs border border-gray-200 rounded-lg text-white text-xs font-semibold cursor-pointer space-x-1' > -
{t('appDebug.promptMode.switchBack')}
void - onSave: (selectedPages: NotionPageSelectorValue[]) => void + onSave: (selectedPages: NotionPage[]) => void datasetId: string } const NotionPageSelectorModal = ({ @@ -20,12 +20,12 @@ const NotionPageSelectorModal = ({ datasetId, }: NotionPageSelectorModalProps) => { const { t } = useTranslation() - const [selectedPages, setSelectedPages] = useState([]) + const [selectedPages, setSelectedPages] = useState([]) const handleClose = () => { onClose() } - const handleSelectPage = (newSelectedPages: NotionPageSelectorValue[]) => { + const handleSelectPage = (newSelectedPages: NotionPage[]) => { setSelectedPages(newSelectedPages) } const handleSave = () => { diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index 24da8855fa..dee983690b 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -191,7 +191,7 @@ const SimpleSelect: FC = ({ onClick={(e) => { e.stopPropagation() setSelectedItem(null) - onSelect({ value: null }) + onSelect({ name: '', value: '' }) }} className="h-5 w-5 text-gray-400 cursor-pointer" aria-hidden="false" diff --git a/web/app/components/explore/category.tsx b/web/app/components/explore/category.tsx index 2b6cfbd9be..cf655c5333 100644 --- a/web/app/components/explore/category.tsx +++ b/web/app/components/explore/category.tsx @@ -28,7 +28,7 @@ const Category: FC = ({ allCategoriesEn, }) => { const { t } = useTranslation() - const isAllCategories = !list.includes(value) + const isAllCategories = !list.includes(value as AppCategory) const itemClassName = (isSelected: boolean) => cn( 'flex items-center px-3 py-[7px] h-[32px] rounded-lg border-[0.5px] border-transparent text-gray-700 font-medium leading-[18px] cursor-pointer hover:bg-gray-200', diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index abc81262b9..1547032163 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -12,6 +12,7 @@ export enum FormTypeEnum { secretInput = 'secret-input', select = 'select', radio = 'radio', + boolean = 'boolean', files = 'files', } diff --git a/web/app/components/header/dataset-nav/index.tsx b/web/app/components/header/dataset-nav/index.tsx index f415658eee..abf76608a8 100644 --- a/web/app/components/header/dataset-nav/index.tsx +++ b/web/app/components/header/dataset-nav/index.tsx @@ -11,6 +11,7 @@ import useSWR from 'swr' import useSWRInfinite from 'swr/infinite' import { flatten } from 'lodash-es' import Nav from '../nav' +import type { NavItem } from '../nav/nav-selector' import { fetchDatasetDetail, fetchDatasets } from '@/service/datasets' import type { DataSetListResponse } from '@/models/datasets' @@ -31,7 +32,7 @@ const DatasetNav = () => { datasetId, } : null, - apiParams => fetchDatasetDetail(apiParams.datasetId)) + apiParams => fetchDatasetDetail(apiParams.datasetId as string)) const { data: datasetsData, setSize } = useSWRInfinite(datasetId ? getKey : () => null, fetchDatasets, { revalidateFirstPage: false, revalidateAll: true }) const datasetItems = flatten(datasetsData?.map(datasetData => datasetData.data)) @@ -46,14 +47,14 @@ const DatasetNav = () => { text={t('common.menus.datasets')} activeSegment='datasets' link='/datasets' - curNav={currentDataset} + curNav={currentDataset as Omit} navs={datasetItems.map(dataset => ({ id: dataset.id, name: dataset.name, link: `/datasets/${dataset.id}/documents`, icon: dataset.icon, icon_background: dataset.icon_background, - }))} + })) as NavItem[]} createText={t('common.menus.newDataset')} onCreate={() => router.push('/datasets/create')} onLoadmore={handleLoadmore} diff --git a/web/app/components/header/nav/nav-selector/index.tsx b/web/app/components/header/nav/nav-selector/index.tsx index 51192c6580..26f538d72d 100644 --- a/web/app/components/header/nav/nav-selector/index.tsx +++ b/web/app/components/header/nav/nav-selector/index.tsx @@ -23,13 +23,13 @@ export type NavItem = { link: string icon: string icon_background: string - mode: string + mode?: string } export type INavSelectorProps = { navs: NavItem[] curNav?: Omit createText: string - isApp: boolean + isApp?: boolean onCreate: (state: string) => void onLoadmore?: () => void } diff --git a/web/app/components/tools/provider/card.tsx b/web/app/components/tools/provider/card.tsx index 7f87d65e3a..6a688186cf 100644 --- a/web/app/components/tools/provider/card.tsx +++ b/web/app/components/tools/provider/card.tsx @@ -36,7 +36,7 @@ const ProviderCard = ({ }, [collection.labels, labelList, language]) return ( -
+
{typeof collection.icon === 'string' && ( diff --git a/web/app/components/tools/provider/detail.tsx b/web/app/components/tools/provider/detail.tsx index ee02e4966d..546b9cd9a1 100644 --- a/web/app/components/tools/provider/detail.tsx +++ b/web/app/components/tools/provider/detail.tsx @@ -85,7 +85,7 @@ const ProviderDetail = ({ const [customCollection, setCustomCollection] = useState(null) const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false) const [showConfirmDelete, setShowConfirmDelete] = useState(false) - const [deleteAction, setDeleteAction] = useState(null) + const [deleteAction, setDeleteAction] = useState('') const doUpdateCustomToolCollection = async (data: CustomCollectionBackend) => { await updateCustomCollection(data) onRefreshData() diff --git a/web/app/components/tools/workflow-tool/index.tsx b/web/app/components/tools/workflow-tool/index.tsx index 436b2c55ab..0f9fe4c4c1 100644 --- a/web/app/components/tools/workflow-tool/index.tsx +++ b/web/app/components/tools/workflow-tool/index.tsx @@ -173,7 +173,7 @@ const WorkflowToolAsModal: FC = ({
{t('tools.createTool.description')}