From e7c3e4cd21b2796729dcf8355145c7a21a6e5013 Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 22 Jan 2026 15:05:35 +0800 Subject: [PATCH] feat: introduce attribute management system for sandbox - Added AttrMap and AttrKey classes for type-safe attribute storage. - Implemented AppAssetsAttrs and SkillAttrs for managing application and skill attributes. - Refactored Sandbox and initializers to utilize the new attribute management system, enhancing modularity and clarity in asset handling. --- api/core/app_assets/constants.py | 7 + api/core/sandbox/builder.py | 6 +- .../initializer/app_assets_initializer.py | 43 +-- api/core/sandbox/initializer/base.py | 4 +- .../initializer/dify_cli_initializer.py | 19 +- .../sandbox/initializer/skill_initializer.py | 43 +++ api/core/sandbox/manager.py | 4 +- api/core/sandbox/sandbox.py | 7 + api/core/skill/__init__.py | 2 + api/core/skill/constants.py | 7 + api/libs/attr_map.py | 164 ++++++++++++ api/services/app_asset_service.py | 16 ++ api/tests/unit_tests/libs/test_attr_map.py | 251 ++++++++++++++++++ 13 files changed, 524 insertions(+), 49 deletions(-) create mode 100644 api/core/app_assets/constants.py create mode 100644 api/core/sandbox/initializer/skill_initializer.py create mode 100644 api/core/skill/constants.py create mode 100644 api/libs/attr_map.py create mode 100644 api/tests/unit_tests/libs/test_attr_map.py diff --git a/api/core/app_assets/constants.py b/api/core/app_assets/constants.py new file mode 100644 index 0000000000..a13c6cd0a9 --- /dev/null +++ b/api/core/app_assets/constants.py @@ -0,0 +1,7 @@ +from core.app.entities.app_asset_entities import AppAssetFileTree +from libs.attr_map import AttrKey + + +class AppAssetsAttrs: + # Skill artifact set + FILE_TREE = AttrKey("file_tree", AppAssetFileTree) diff --git a/api/core/sandbox/builder.py b/api/core/sandbox/builder.py index 6113fbc8e6..bb59523a98 100644 --- a/api/core/sandbox/builder.py +++ b/api/core/sandbox/builder.py @@ -100,9 +100,6 @@ class SandboxBuilder: environments=self._environments, user_id=self._user_id, ) - for init in self._initializers: - init.initialize(vm) - sandbox = Sandbox( vm=vm, storage=self._storage, @@ -111,6 +108,9 @@ class SandboxBuilder: app_id=self._app_id, assets_id=self._assets_id, ) + for init in self._initializers: + init.initialize(sandbox) + sandbox.mount() return sandbox diff --git a/api/core/sandbox/initializer/app_assets_initializer.py b/api/core/sandbox/initializer/app_assets_initializer.py index 3f990d29ee..da6d9ec935 100644 --- a/api/core/sandbox/initializer/app_assets_initializer.py +++ b/api/core/sandbox/initializer/app_assets_initializer.py @@ -1,10 +1,12 @@ import logging +from core.app_assets.constants import AppAssetsAttrs from core.app_assets.paths import AssetPaths +from core.sandbox.sandbox import Sandbox from core.virtual_environment.__base.helpers import pipeline -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from extensions.ext_storage import storage from extensions.storage.file_presign_storage import FilePresignStorage +from services.app_asset_service import AppAssetService from ..entities import AppAssets from .base import SandboxInitializer @@ -20,42 +22,17 @@ class AppAssetsInitializer(SandboxInitializer): self._app_id = app_id self._assets_id = assets_id - def initialize(self, env: VirtualEnvironment) -> None: + def initialize(self, sandbox: Sandbox) -> None: + vm = sandbox.vm + # load app assets + app_assets = AppAssetService.get_tenant_app_assets(self._tenant_id, self._assets_id) + sandbox.attrs.set(AppAssetsAttrs.FILE_TREE, app_assets.asset_tree) + zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id) download_url = FilePresignStorage(storage.storage_runner).get_download_url(zip_key) ( - pipeline(env) - .add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip") - # unzip with silent error and return 1 if the zip is empty - # FIXME(Mairuis): should use a more robust way to check if the zip is empty - .add( - ["sh", "-c", f"unzip {AppAssets.ZIP_PATH} -d {AppAssets.PATH} 2>/dev/null || [ $? -eq 1 ]"], - error_message="Failed to unzip assets", - ) - .execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True) - ) - - logger.info( - "App assets initialized for app_id=%s, published_id=%s", - self._app_id, - self._assets_id, - ) - - -class DraftAppAssetsInitializer(SandboxInitializer): - def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None: - self._tenant_id = tenant_id - self._app_id = app_id - self._assets_id = assets_id - - def initialize(self, env: VirtualEnvironment) -> None: - zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id) - download_url = FilePresignStorage(storage.storage_runner).get_download_url(zip_key) - - ( - pipeline(env) - .add(["rm", "-rf", AppAssets.PATH]) + pipeline(vm) .add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip") # unzip with silent error and return 1 if the zip is empty # FIXME(Mairuis): should use a more robust way to check if the zip is empty diff --git a/api/core/sandbox/initializer/base.py b/api/core/sandbox/initializer/base.py index 937b09c2dc..c2213d9714 100644 --- a/api/core/sandbox/initializer/base.py +++ b/api/core/sandbox/initializer/base.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from core.sandbox.sandbox import Sandbox class SandboxInitializer(ABC): @abstractmethod - def initialize(self, env: VirtualEnvironment) -> None: ... + def initialize(self, env: Sandbox) -> None: ... diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index fa6f57ce69..0cc77ee0d8 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -5,10 +5,10 @@ import logging from io import BytesIO from pathlib import Path +from core.sandbox.sandbox import Sandbox from core.session.cli_api import CliApiSessionManager from core.skill.skill_manager import SkillManager from core.virtual_environment.__base.helpers import pipeline -from core.virtual_environment.__base.virtual_environment import VirtualEnvironment from ..bash.dify_cli import DifyCliConfig, DifyCliLocator from ..entities import DifyCli @@ -35,18 +35,19 @@ class DifyCliInitializer(SandboxInitializer): self._tools = [] self._cli_api_session = None - def initialize(self, env: VirtualEnvironment) -> None: - binary = self._locator.resolve(env.metadata.os, env.metadata.arch) + def initialize(self, sandbox: Sandbox) -> None: + vm = sandbox.vm + binary = self._locator.resolve(vm.metadata.os, vm.metadata.arch) - pipeline(env).add( + pipeline(vm).add( ["mkdir", "-p", f"{DifyCli.ROOT}/bin"], error_message="Failed to create dify CLI directory" ).execute(raise_on_error=True) - env.upload_file(DifyCli.PATH, BytesIO(binary.path.read_bytes())) + vm.upload_file(DifyCli.PATH, BytesIO(binary.path.read_bytes())) # Use 'cp' with mode preservation workaround: copy file to itself to claim ownership, # then use 'install' to set executable permission - pipeline(env).add( + pipeline(vm).add( [ "sh", "-c", @@ -67,16 +68,16 @@ class DifyCliInitializer(SandboxInitializer): # FIXME(Mairuis): store it in workflow context self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id) - pipeline(env).add( + pipeline(vm).add( ["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir" ).execute(raise_on_error=True) config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, artifact) config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False) config_path = f"{DifyCli.GLOBAL_TOOLS_PATH}/{DifyCli.CONFIG_FILENAME}" - env.upload_file(config_path, BytesIO(config_json.encode("utf-8"))) + vm.upload_file(config_path, BytesIO(config_json.encode("utf-8"))) - pipeline(env, cwd=DifyCli.GLOBAL_TOOLS_PATH).add( + pipeline(vm, cwd=DifyCli.GLOBAL_TOOLS_PATH).add( [DifyCli.PATH, "init"], error_message="Failed to initialize Dify CLI" ).execute(raise_on_error=True) diff --git a/api/core/sandbox/initializer/skill_initializer.py b/api/core/sandbox/initializer/skill_initializer.py new file mode 100644 index 0000000000..5a0b0adf81 --- /dev/null +++ b/api/core/sandbox/initializer/skill_initializer.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import logging + +from core.sandbox.sandbox import Sandbox +from core.skill import SkillAttrs +from core.skill.skill_manager import SkillManager + +from .base import SandboxInitializer + +logger = logging.getLogger(__name__) + + +class SkillInitializer(SandboxInitializer): + def __init__( + self, + tenant_id: str, + user_id: str, + app_id: str, + assets_id: str, + ) -> None: + self._tenant_id = tenant_id + self._app_id = app_id + self._user_id = user_id + self._assets_id = assets_id + + def initialize(self, sandbox: Sandbox) -> None: + artifact_set = SkillManager.load_artifact( + self._tenant_id, + self._app_id, + self._assets_id, + ) + if artifact_set is None: + raise ValueError( + f"No skill artifact set found for tenant_id={self._tenant_id}," + f"app_id={self._app_id}, " + f"assets_id={self._assets_id} " + ) + + sandbox.attrs.set( + SkillAttrs.ARTIFACT_SET, + artifact_set, + ) diff --git a/api/core/sandbox/manager.py b/api/core/sandbox/manager.py index 01a824ad4a..ee50f9502f 100644 --- a/api/core/sandbox/manager.py +++ b/api/core/sandbox/manager.py @@ -7,7 +7,7 @@ from typing import Final from core.sandbox.builder import SandboxBuilder from core.sandbox.entities import AppAssets, SandboxType from core.sandbox.entities.providers import SandboxProviderEntity -from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer, DraftAppAssetsInitializer +from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer from core.sandbox.sandbox import Sandbox from core.sandbox.storage.archive_storage import ArchiveSandboxStorage @@ -151,7 +151,7 @@ class SandboxManager: .options(sandbox_provider.config) .user(user_id) .app(app_id) - .initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id)) + .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id)) .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) .storage(storage, assets.id) .build() diff --git a/api/core/sandbox/sandbox.py b/api/core/sandbox/sandbox.py index 7032238277..1fc61895e1 100644 --- a/api/core/sandbox/sandbox.py +++ b/api/core/sandbox/sandbox.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging from typing import TYPE_CHECKING +from libs.attr_map import AttrMap + if TYPE_CHECKING: from core.sandbox.storage.sandbox_storage import SandboxStorage from core.virtual_environment.__base.virtual_environment import VirtualEnvironment @@ -27,6 +29,11 @@ class Sandbox: self._user_id = user_id self._app_id = app_id self._assets_id = assets_id + self._attributes = AttrMap() + + @property + def attrs(self) -> AttrMap: + return self._attributes @property def vm(self) -> VirtualEnvironment: diff --git a/api/core/skill/__init__.py b/api/core/skill/__init__.py index 6515a514ab..c45eb27ccf 100644 --- a/api/core/skill/__init__.py +++ b/api/core/skill/__init__.py @@ -1,7 +1,9 @@ +from .constants import SkillAttrs from .entities import ToolArtifact, ToolDependency, ToolReference from .skill_manager import SkillManager __all__ = [ + "SkillAttrs", "SkillManager", "ToolArtifact", "ToolDependency", diff --git a/api/core/skill/constants.py b/api/core/skill/constants.py new file mode 100644 index 0000000000..9a1e3327c4 --- /dev/null +++ b/api/core/skill/constants.py @@ -0,0 +1,7 @@ +from core.skill.entities.skill_artifact_set import SkillArtifactSet +from libs.attr_map import AttrKey + + +class SkillAttrs: + # Skill artifact set + ARTIFACT_SET = AttrKey("skill_artifact_set", SkillArtifactSet) diff --git a/api/libs/attr_map.py b/api/libs/attr_map.py new file mode 100644 index 0000000000..a200d72d42 --- /dev/null +++ b/api/libs/attr_map.py @@ -0,0 +1,164 @@ +""" +Type-safe attribute storage inspired by Netty's AttributeKey/AttributeMap pattern. + +Provides loosely-coupled typed attribute storage where only code with access +to the same AttrKey instance can read/write the corresponding attribute. + + SESSION_KEY: AttrKey[Session] = AttrKey("session", Session) + attrs = AttrMap() + attrs.set(SESSION_KEY, session) + session = attrs.get(SESSION_KEY) # -> Session | None + session = attrs.require(SESSION_KEY) # -> Session (raises if not set) + +Note: AttrMap is NOT thread-safe. Each instance should be confined to a single +thread/context (e.g., one AttrMap per Sandbox/VirtualEnvironment instance). +""" + +from __future__ import annotations + +from typing import Any, Generic, TypeVar, cast, final, overload + +T = TypeVar("T") +D = TypeVar("D") + + +@final +class AttrKey(Generic[T]): + """ + A type-safe key for attribute storage. + + Identity-based: different AttrKey instances with same name are distinct keys. + This enables different modules to define keys independently without collision. + """ + + __slots__ = ("_name", "_type") + + def __init__(self, name: str, type_: type[T]) -> None: + self._name = name + self._type = type_ + + @property + def name(self) -> str: + return self._name + + @property + def type_(self) -> type[T]: + return self._type + + def __repr__(self) -> str: + return f"AttrKey({self._name!r}, {self._type.__name__})" + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: object) -> bool: + return self is other + + +class AttrMapKeyError(KeyError): + """Raised when a required attribute is not set.""" + + key: AttrKey[Any] + + def __init__(self, key: AttrKey[Any]) -> None: + self.key = key + super().__init__(f"Required attribute '{key.name}' (type: {key.type_.__name__}) is not set") + + +class AttrMapTypeError(TypeError): + """Raised when attribute value type doesn't match the key's declared type.""" + + key: AttrKey[Any] + expected_type: type[Any] + actual_type: type[Any] + + def __init__(self, key: AttrKey[Any], expected_type: type[Any], actual_type: type[Any]) -> None: + self.key = key + self.expected_type = expected_type + self.actual_type = actual_type + super().__init__( + f"Attribute '{key.name}' expects type '{expected_type.__name__}', " + f"got '{actual_type.__name__}'" + ) + + +@final +class AttrMap: + """ + Thread-confined container for storing typed attributes using AttrKey instances. + + NOT thread-safe. Each instance should be owned by a single context + (e.g., one AttrMap per Sandbox/VirtualEnvironment instance). + """ + + __slots__ = ("_data",) + + def __init__(self) -> None: + self._data: dict[AttrKey[Any], Any] = {} + + def set(self, key: AttrKey[T], value: T, *, validate: bool = True) -> None: + """ + Store an attribute. Raises AttrMapTypeError if validate=True and type mismatches. + + Note: Runtime validation only checks outer type (e.g., `list` not `list[str]`). + """ + if validate and not isinstance(value, key.type_): + raise AttrMapTypeError(key, key.type_, type(value)) + self._data[key] = value + + def get(self, key: AttrKey[T]) -> T | None: + """Retrieve an attribute, returning None if not set.""" + return cast(T | None, self._data.get(key)) + + @overload + def get_or_default(self, key: AttrKey[T], default: T) -> T: ... + + @overload + def get_or_default(self, key: AttrKey[T], default: D) -> T | D: ... + + def get_or_default(self, key: AttrKey[T], default: T | D) -> T | D: + """Retrieve an attribute, returning default if not set.""" + if key in self._data: + return cast(T, self._data[key]) + return default + + def require(self, key: AttrKey[T]) -> T: + """Retrieve an attribute, raising AttrMapKeyError if not set.""" + if key not in self._data: + raise AttrMapKeyError(key) + return cast(T, self._data[key]) + + def has(self, key: AttrKey[Any]) -> bool: + """Check if an attribute is set.""" + return key in self._data + + def remove(self, key: AttrKey[Any]) -> bool: + """Remove an attribute. Returns True if it was present.""" + if key in self._data: + del self._data[key] + return True + return False + + def set_if_absent(self, key: AttrKey[T], value: T, *, validate: bool = True) -> T: + """ + Set attribute only if not already set. Returns existing or newly set value. + + Raises AttrMapTypeError if validate=True and type mismatches. + """ + if key in self._data: + return cast(T, self._data[key]) + if validate and not isinstance(value, key.type_): + raise AttrMapTypeError(key, key.type_, type(value)) + self._data[key] = value + return value + + def clear(self) -> None: + """Remove all attributes.""" + self._data.clear() + + def __len__(self) -> int: + return len(self._data) + + def __repr__(self) -> str: + keys = [k.name for k in self._data] + return f"AttrMap({keys})" diff --git a/api/services/app_asset_service.py b/api/services/app_asset_service.py index 76db7ce044..f2e046da0e 100644 --- a/api/services/app_asset_service.py +++ b/api/services/app_asset_service.py @@ -58,6 +58,22 @@ class AppAssetService: session.commit() return assets + @staticmethod + def get_tenant_app_assets(tenant_id: str, assets_id: str) -> AppAssets: + with Session(db.engine, expire_on_commit=False) as session: + app_assets = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == tenant_id, + AppAssets.id == assets_id, + ) + .first() + ) + if not app_assets: + raise ValueError(f"App assets not found for tenant_id={tenant_id}, assets_id={assets_id}") + + return app_assets + @staticmethod def get_assets(tenant_id: str, app_id: str, user_id: str, *, is_draft: bool) -> AppAssets | None: with Session(db.engine, expire_on_commit=False) as session: diff --git a/api/tests/unit_tests/libs/test_attr_map.py b/api/tests/unit_tests/libs/test_attr_map.py new file mode 100644 index 0000000000..0e0bfcea1d --- /dev/null +++ b/api/tests/unit_tests/libs/test_attr_map.py @@ -0,0 +1,251 @@ +import pytest + +from libs.attr_map import AttrKey, AttrMap, AttrMapKeyError, AttrMapTypeError + + +class TestAttrKey: + def test_identity_based_equality(self): + key1 = AttrKey("session", str) + key2 = AttrKey("session", str) + + assert key1 != key2 + assert key1 == key1 + + def test_identity_based_hash(self): + key1 = AttrKey("session", str) + key2 = AttrKey("session", str) + + assert hash(key1) != hash(key2) + assert hash(key1) == hash(key1) + + def test_can_be_used_as_dict_key(self): + key1 = AttrKey("session", str) + key2 = AttrKey("session", str) + data: dict[AttrKey[str], str] = {} + + data[key1] = "value1" + data[key2] = "value2" + + assert data[key1] == "value1" + assert data[key2] == "value2" + assert len(data) == 2 + + def test_properties(self): + key = AttrKey("my_key", int) + + assert key.name == "my_key" + assert key.type_ is int + + def test_repr(self): + key = AttrKey("session", str) + assert repr(key) == "AttrKey('session', str)" + + +class TestAttrMap: + def test_set_and_get(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + attrs.set(key, "hello") + result = attrs.get(key) + + assert result == "hello" + + def test_get_returns_none_for_missing(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + assert attrs.get(key) is None + + def test_get_or_default_returns_value_when_set(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + attrs.set(key, "hello") + + result = attrs.get_or_default(key, "default") + + assert result == "hello" + + def test_get_or_default_returns_default_when_not_set(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + result = attrs.get_or_default(key, "default") + + assert result == "default" + + def test_require_returns_value_when_set(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + attrs.set(key, "hello") + + result = attrs.require(key) + + assert result == "hello" + + def test_require_raises_when_not_set(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + with pytest.raises(AttrMapKeyError) as exc_info: + attrs.require(key) + + assert exc_info.value.key is key + assert "session" in str(exc_info.value) + + def test_has(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + assert not attrs.has(key) + + attrs.set(key, "hello") + + assert attrs.has(key) + + def test_remove_existing(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + attrs.set(key, "hello") + + result = attrs.remove(key) + + assert result is True + assert not attrs.has(key) + + def test_remove_non_existing(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + result = attrs.remove(key) + + assert result is False + + def test_set_if_absent_when_absent(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + result = attrs.set_if_absent(key, "first") + + assert result == "first" + assert attrs.get(key) == "first" + + def test_set_if_absent_when_present(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + attrs.set(key, "existing") + + result = attrs.set_if_absent(key, "new") + + assert result == "existing" + assert attrs.get(key) == "existing" + + def test_clear(self): + key1: AttrKey[str] = AttrKey("key1", str) + key2: AttrKey[int] = AttrKey("key2", int) + attrs = AttrMap() + attrs.set(key1, "hello") + attrs.set(key2, 42) + + attrs.clear() + + assert len(attrs) == 0 + assert not attrs.has(key1) + assert not attrs.has(key2) + + def test_len(self): + key1: AttrKey[str] = AttrKey("key1", str) + key2: AttrKey[int] = AttrKey("key2", int) + attrs = AttrMap() + + assert len(attrs) == 0 + + attrs.set(key1, "hello") + assert len(attrs) == 1 + + attrs.set(key2, 42) + assert len(attrs) == 2 + + def test_repr(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + attrs.set(key, "hello") + + result = repr(attrs) + + assert "AttrMap" in result + assert "session" in result + + +class TestAttrMapTypeValidation: + def test_set_with_wrong_type_raises(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + with pytest.raises(AttrMapTypeError) as exc_info: + attrs.set(key, 123) # type: ignore[arg-type] + + assert exc_info.value.key is key + assert exc_info.value.expected_type is str + assert exc_info.value.actual_type is int + + def test_set_with_validate_false_allows_wrong_type(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + attrs.set(key, 123, validate=False) # type: ignore[arg-type] + + assert attrs.get(key) == 123 + + def test_set_if_absent_with_wrong_type_raises(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + with pytest.raises(AttrMapTypeError): + attrs.set_if_absent(key, 123) # type: ignore[arg-type] + + def test_set_if_absent_with_validate_false_allows_wrong_type(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + attrs.set_if_absent(key, 123, validate=False) # type: ignore[arg-type] + + assert attrs.get(key) == 123 + + def test_subclass_type_validation(self): + class Animal: + pass + + class Dog(Animal): + pass + + key: AttrKey[Animal] = AttrKey("animal", Animal) + attrs = AttrMap() + + attrs.set(key, Dog()) + + assert isinstance(attrs.get(key), Dog) + + +class TestAttrMapIsolation: + def test_different_keys_with_same_name_are_isolated(self): + key_in_module_a: AttrKey[str] = AttrKey("config", str) + key_in_module_b: AttrKey[str] = AttrKey("config", str) + attrs = AttrMap() + + attrs.set(key_in_module_a, "value_a") + attrs.set(key_in_module_b, "value_b") + + assert attrs.get(key_in_module_a) == "value_a" + assert attrs.get(key_in_module_b) == "value_b" + + def test_multiple_attr_maps_are_independent(self): + key: AttrKey[str] = AttrKey("session", str) + attrs1 = AttrMap() + attrs2 = AttrMap() + + attrs1.set(key, "map1") + attrs2.set(key, "map2") + + assert attrs1.get(key) == "map1" + assert attrs2.get(key) == "map2"