From 2d29345f2631c33b6ca55d00af4ac668378ff691 Mon Sep 17 00:00:00 2001 From: YBoy Date: Thu, 2 Apr 2026 03:47:08 +0200 Subject: [PATCH] =?UTF-8?q?refactor(api):=20type=20OpsTraceProviderConfigM?= =?UTF-8?q?ap=20with=20TracingProviderCon=E2=80=A6=20(#34424)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/core/ops/ops_trace_manager.py | 24 ++++++++++++++++-------- api/services/ops_service.py | 6 ++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index c689a86614..aa39e6b681 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -19,6 +19,7 @@ from typing_extensions import TypedDict from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( OPS_FILE_PATH, + BaseTracingConfig, TracingProviderEnum, ) from core.ops.entities.trace_entity import ( @@ -195,8 +196,15 @@ def _lookup_llm_credential_info( return None, "" -class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, provider: str) -> dict[str, Any]: +class TracingProviderConfigEntry(TypedDict): + config_class: type[BaseTracingConfig] + secret_keys: list[str] + other_keys: list[str] + trace_instance: type[Any] + + +class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): + def __getitem__(self, provider: str) -> TracingProviderConfigEntry: match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig @@ -585,8 +593,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).api_check() + config = config_type(**tracing_config) + return trace_instance(config).api_check() @staticmethod def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): @@ -600,8 +608,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).get_project_key() + config = config_type(**tracing_config) + return trace_instance(config).get_project_key() @staticmethod def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): @@ -615,8 +623,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).get_project_url() + config = config_type(**tracing_config) + return trace_instance(config).get_project_url() class TraceTask: diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 2a64088dd6..0db3d3efec 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,9 +1,7 @@ -from typing import Any - from sqlalchemy import select from core.ops.entities.config_entity import BaseTracingConfig -from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map +from core.ops.ops_trace_manager import OpsTraceManager, TracingProviderConfigEntry, provider_config_map from extensions.ext_database import db from models.model import App, TraceAppConfig @@ -150,7 +148,7 @@ class OpsService: except KeyError: return {"error": f"Invalid tracing provider: {tracing_provider}"} - provider_config: dict[str, Any] = provider_config_map[tracing_provider] + provider_config: TracingProviderConfigEntry = provider_config_map[tracing_provider] config_class: type[BaseTracingConfig] = provider_config["config_class"] other_keys: list[str] = provider_config["other_keys"]