mirror of https://github.com/langgenius/dify.git
refactor: replace try-except blocks with contextlib.suppress for cleaner exception handling (#24284)
This commit is contained in:
parent
ad8e82ee1d
commit
1abf1240b2
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
|
@ -178,7 +179,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
|||
def cloud_utm_record(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
|
|
@ -187,8 +188,7 @@ def cloud_utm_record(view):
|
|||
if utm_info:
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
|
@ -97,10 +98,8 @@ def parse_traceparent_header(traceparent: str) -> Optional[str]:
|
|||
Reference:
|
||||
W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
|
||||
"""
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
parts = traceparent.split("-")
|
||||
if len(parts) == 4 and len(parts[1]) == 32:
|
||||
return parts[1]
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from json import JSONDecodeError
|
||||
|
|
@ -624,14 +625,12 @@ class ProviderManager:
|
|||
|
||||
for variable in provider_credential_secret_variables:
|
||||
if variable in provider_credentials:
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable) or "", # type: ignore
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(credentials=provider_credentials)
|
||||
|
|
@ -672,14 +671,12 @@ class ProviderManager:
|
|||
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_credentials:
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# cache provider model credentials
|
||||
provider_model_credentials_cache.set(credentials=provider_model_credentials)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import queue
|
||||
|
|
@ -214,10 +215,8 @@ class ClickzettaConnectionPool:
|
|||
return connection
|
||||
else:
|
||||
# Connection expired or invalid, close it
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# No valid connection found, create new one
|
||||
return self._create_connection(config)
|
||||
|
|
@ -228,10 +227,8 @@ class ClickzettaConnectionPool:
|
|||
|
||||
if config_key not in self._pool_locks:
|
||||
# Pool was cleaned up, just close the connection
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
with self._pool_locks[config_key]:
|
||||
|
|
@ -243,10 +240,8 @@ class ClickzettaConnectionPool:
|
|||
logger.debug("Returned ClickZetta connection to pool")
|
||||
else:
|
||||
# Pool full or connection invalid, close it
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _cleanup_expired_connections(self) -> None:
|
||||
"""Clean up expired connections from all pools."""
|
||||
|
|
@ -265,10 +260,8 @@ class ClickzettaConnectionPool:
|
|||
if current_time - last_used < self._connection_timeout:
|
||||
valid_connections.append((connection, last_used))
|
||||
else:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._pools[config_key] = valid_connections
|
||||
|
||||
|
|
@ -299,10 +292,8 @@ class ClickzettaConnectionPool:
|
|||
with self._pool_locks[config_key]:
|
||||
pool = self._pools[config_key]
|
||||
for connection, _ in pool:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
connection.close()
|
||||
except Exception:
|
||||
pass
|
||||
pool.clear()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Abstract interface for document loader implementations."""
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional, cast
|
||||
|
||||
|
|
@ -25,12 +26,10 @@ class PdfExtractor(BaseExtractor):
|
|||
def extract(self) -> list[Document]:
|
||||
plaintext_file_exists = False
|
||||
if self._file_cache_key:
|
||||
try:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
|
||||
plaintext_file_exists = True
|
||||
return [Document(page_content=text)]
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
documents = list(self.load())
|
||||
text_list = []
|
||||
for document in documents:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import base64
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -33,7 +34,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
|||
elements = partition_email(filename=self._file_path)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
for element in elements:
|
||||
element_text = element.text.strip()
|
||||
|
||||
|
|
@ -43,8 +44,6 @@ class UnstructuredEmailExtractor(BaseExtractor):
|
|||
element_decode = base64.b64decode(element_text)
|
||||
soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
|
||||
element.text = soup.get_text()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import base64
|
||||
import contextlib
|
||||
import enum
|
||||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
|
|
@ -227,10 +228,8 @@ class ToolInvokeMessage(BaseModel):
|
|||
@classmethod
|
||||
def decode_blob_message(cls, v):
|
||||
if isinstance(v, dict) and "blob" in v:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
v["blob"] = base64.b64decode(v["blob"])
|
||||
except Exception:
|
||||
pass
|
||||
return v
|
||||
|
||||
@field_serializer("message")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
from collections.abc import Generator, Iterable
|
||||
from copy import deepcopy
|
||||
|
|
@ -69,10 +70,8 @@ class ToolEngine:
|
|||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {parameters[0].name: tool_parameters}
|
||||
else:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
tool_parameters = json.loads(tool_parameters)
|
||||
except Exception:
|
||||
pass
|
||||
if not isinstance(tool_parameters, dict):
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
|
|
@ -270,14 +269,12 @@ class ToolEngine:
|
|||
if response.meta.get("mime_type"):
|
||||
mimetype = response.meta.get("mime_type")
|
||||
else:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f"a{extension}")
|
||||
if guess_type_result:
|
||||
mimetype = guess_type_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not mimetype:
|
||||
mimetype = "image/jpeg"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -137,11 +138,9 @@ class ToolParameterConfigurationManager:
|
|||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
has_secret_input = True
|
||||
with contextlib.suppress(Exception):
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
|
|
@ -111,14 +112,12 @@ class ProviderConfigEncrypter:
|
|||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
|
@ -666,10 +667,8 @@ class ParameterExtractorNode(BaseNode):
|
|||
if result[idx] == "{" or result[idx] == "[":
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
|
|
@ -686,10 +685,9 @@ class ParameterExtractorNode(BaseNode):
|
|||
if result[idx] == "{" or result[idx] == "[":
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
|
@ -38,12 +39,11 @@ def handle(sender, **kwargs):
|
|||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import atexit
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
|
|
@ -106,7 +107,7 @@ def init_app(app: DifyApp):
|
|||
"""Custom logging handler that creates spans for logging.exception() calls"""
|
||||
|
||||
def emit(self, record: logging.LogRecord):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
if record.exc_info:
|
||||
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
|
||||
with tracer.start_as_current_span(
|
||||
|
|
@ -126,9 +127,6 @@ def init_app(app: DifyApp):
|
|||
if record.exc_info[0]:
|
||||
span.set_attribute("exception.type", record.exc_info[0].__name__)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
|
|
@ -142,13 +143,11 @@ class ConversationService:
|
|||
raise MessageNotExistsError()
|
||||
|
||||
# generate conversation name
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
name = LLMGenerator.generate_conversation_name(
|
||||
app_model.tenant_id, message.query, conversation.id, app_model.id
|
||||
)
|
||||
conversation.name = name
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
|
@ -44,10 +45,8 @@ class TestClickzettaVector(AbstractVectorTest):
|
|||
yield vector
|
||||
|
||||
# Cleanup: delete the test collection
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
vector.delete()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_clickzetta_vector_basic_operations(self, vector_store):
|
||||
"""Test basic CRUD operations on Clickzetta vector store."""
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import contextlib
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
|
|
@ -124,13 +125,10 @@ def test_sse_client_connection_validation():
|
|||
mock_event_source.iter_sse.return_value = [endpoint_event]
|
||||
|
||||
# Test connection
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
with sse_client(test_url) as (read_queue, write_queue):
|
||||
assert read_queue is not None
|
||||
assert write_queue is not None
|
||||
except Exception as e:
|
||||
# Connection might fail due to mocking, but we're testing the validation logic
|
||||
pass
|
||||
|
||||
|
||||
def test_sse_client_error_handling():
|
||||
|
|
@ -178,7 +176,7 @@ def test_sse_client_timeout_configuration():
|
|||
mock_event_source.iter_sse.return_value = []
|
||||
mock_sse_connect.return_value.__enter__.return_value = mock_event_source
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
with sse_client(
|
||||
test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
|
||||
) as (read_queue, write_queue):
|
||||
|
|
@ -190,9 +188,6 @@ def test_sse_client_timeout_configuration():
|
|||
assert call_args is not None
|
||||
timeout_arg = call_args[1]["timeout"]
|
||||
assert timeout_arg.read == custom_sse_timeout
|
||||
except Exception:
|
||||
# Connection might fail due to mocking, but we tested the configuration
|
||||
pass
|
||||
|
||||
|
||||
def test_sse_transport_endpoint_validation():
|
||||
|
|
@ -251,12 +246,10 @@ def test_sse_client_queue_cleanup():
|
|||
# Mock connection that raises an exception
|
||||
mock_sse_connect.side_effect = Exception("Connection failed")
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
with sse_client(test_url) as (rq, wq):
|
||||
read_queue = rq
|
||||
write_queue = wq
|
||||
except Exception:
|
||||
pass # Expected to fail
|
||||
|
||||
# Queues should be cleaned up even on exception
|
||||
# Note: In real implementation, cleanup should put None to signal shutdown
|
||||
|
|
@ -283,11 +276,9 @@ def test_sse_client_headers_propagation():
|
|||
mock_event_source.iter_sse.return_value = []
|
||||
mock_sse_connect.return_value.__enter__.return_value = mock_event_source
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
with sse_client(test_url, headers=custom_headers):
|
||||
pass
|
||||
except Exception:
|
||||
pass # Expected due to mocking
|
||||
|
||||
# Verify headers were passed to client factory
|
||||
mock_client_factory.assert_called_with(headers=custom_headers)
|
||||
|
|
|
|||
Loading…
Reference in New Issue