diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index c689a86614..aa39e6b681 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -19,6 +19,7 @@ from typing_extensions import TypedDict from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( OPS_FILE_PATH, + BaseTracingConfig, TracingProviderEnum, ) from core.ops.entities.trace_entity import ( @@ -195,8 +196,15 @@ def _lookup_llm_credential_info( return None, "" -class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, provider: str) -> dict[str, Any]: +class TracingProviderConfigEntry(TypedDict): + config_class: type[BaseTracingConfig] + secret_keys: list[str] + other_keys: list[str] + trace_instance: type[Any] + + +class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): + def __getitem__(self, provider: str) -> TracingProviderConfigEntry: match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig @@ -585,8 +593,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).api_check() + config = config_type(**tracing_config) + return trace_instance(config).api_check() @staticmethod def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): @@ -600,8 +608,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).get_project_key() + config = config_type(**tracing_config) + return trace_instance(config).get_project_key() @staticmethod def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): @@ -615,8 +623,8 @@ class OpsTraceManager: provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["trace_instance"], ) - tracing_config = config_type(**tracing_config) - return trace_instance(tracing_config).get_project_url() + config = config_type(**tracing_config) + return trace_instance(config).get_project_url() class TraceTask: diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 2a64088dd6..0db3d3efec 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,9 +1,7 @@ -from typing import Any - from sqlalchemy import select from core.ops.entities.config_entity import BaseTracingConfig -from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map +from core.ops.ops_trace_manager import OpsTraceManager, TracingProviderConfigEntry, provider_config_map from extensions.ext_database import db from models.model import App, TraceAppConfig @@ -150,7 +148,7 @@ class OpsService: except KeyError: return {"error": f"Invalid tracing provider: {tracing_provider}"} - provider_config: dict[str, Any] = provider_config_map[tracing_provider] + provider_config: TracingProviderConfigEntry = provider_config_map[tracing_provider] config_class: type[BaseTracingConfig] = provider_config["config_class"] other_keys: list[str] = provider_config["other_keys"]