clan_lib: typecast return of get_value_by_path
This commit is contained in:
@@ -9,6 +9,7 @@ from clan_lib.machines.machines import Machine
|
||||
from clan_lib.nix_models.clan import (
|
||||
InventoryInstance,
|
||||
InventoryMachine,
|
||||
InventoryMachineTagsType,
|
||||
)
|
||||
from clan_lib.persist.inventory_store import InventoryStore
|
||||
from clan_lib.persist.util import (
|
||||
@@ -181,11 +182,11 @@ def get_machine_fields_schema(machine: Machine) -> dict[str, FieldSchema]:
|
||||
# TODO: handle this more generically. I.e via json schema
|
||||
persisted_data = inventory_store._get_persisted() # noqa: SLF001
|
||||
inventory = inventory_store.read()
|
||||
all_tags = get_value_by_path(inventory, f"machines.{machine.name}.tags", [])
|
||||
all_tags = get_value_by_path(
|
||||
inventory, f"machines.{machine.name}.tags", [], InventoryMachineTagsType
|
||||
)
|
||||
persisted_tags = get_value_by_path(
|
||||
persisted_data,
|
||||
f"machines.{machine.name}.tags",
|
||||
[],
|
||||
persisted_data, f"machines.{machine.name}.tags", [], InventoryMachineTagsType
|
||||
)
|
||||
nix_tags = list_difference(all_tags, persisted_tags)
|
||||
|
||||
|
||||
@@ -9,7 +9,12 @@ from clan_lib.errors import ClanError
|
||||
from clan_lib.flake import Flake
|
||||
from clan_lib.machines import actions as actions_module
|
||||
from clan_lib.machines.machines import Machine
|
||||
from clan_lib.nix_models.clan import Clan, InventoryMachine, Unknown
|
||||
from clan_lib.nix_models.clan import (
|
||||
Clan,
|
||||
InventoryMachine,
|
||||
InventoryMachineTagsType,
|
||||
Unknown,
|
||||
)
|
||||
from clan_lib.persist.inventory_store import InventoryStore
|
||||
from clan_lib.persist.util import get_value_by_path, set_value_by_path
|
||||
|
||||
@@ -233,7 +238,9 @@ def test_get_machine_writeability(clan_flake: Callable[..., Flake]) -> None:
|
||||
# TODO: Move this into the api
|
||||
inventory_store = InventoryStore(flake=flake)
|
||||
inventory = inventory_store.read()
|
||||
curr_tags = get_value_by_path(inventory, "machines.jon.tags", [])
|
||||
curr_tags = get_value_by_path(
|
||||
inventory, "machines.jon.tags", [], InventoryMachineTagsType
|
||||
)
|
||||
new_tags = ["managed1", "managed2"]
|
||||
set_value_by_path(inventory, "machines.jon.tags", [*curr_tags, *new_tags])
|
||||
inventory_store.write(inventory, message="Test writeability")
|
||||
|
||||
@@ -441,7 +441,15 @@ def delete_by_path(d: dict[str, Any], path: str) -> Any:
|
||||
type DictLike = dict[str, Any] | Any
|
||||
|
||||
|
||||
def get_value_by_path(d: DictLike, path: str, fallback: Any = None) -> Any:
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
def get_value_by_path(
|
||||
d: DictLike,
|
||||
path: str,
|
||||
fallback: V | None = None,
|
||||
expected_type: type[V] | None = None, # noqa: ARG001
|
||||
) -> V:
|
||||
"""Get the value at a specific dot-separated path in a nested dictionary.
|
||||
|
||||
If the path does not exist, it returns fallback.
|
||||
@@ -455,9 +463,9 @@ def get_value_by_path(d: DictLike, path: str, fallback: Any = None) -> Any:
|
||||
current = current.setdefault(key, {})
|
||||
|
||||
if isinstance(current, dict):
|
||||
return current.get(keys[-1], fallback)
|
||||
return cast("V", current.get(keys[-1], fallback))
|
||||
|
||||
return fallback
|
||||
return cast("V", fallback)
|
||||
|
||||
|
||||
def set_value_by_path(d: DictLike, path: str, content: Any) -> None:
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Callable
|
||||
import pytest
|
||||
|
||||
from clan_lib.flake import Flake
|
||||
from clan_lib.nix_models.clan import InventoryMachineTagsType
|
||||
from clan_lib.persist.inventory_store import InventoryStore
|
||||
from clan_lib.persist.util import get_value_by_path, set_value_by_path
|
||||
from clan_lib.tags.list import list_tags
|
||||
@@ -45,7 +46,9 @@ def test_list_inventory_tags(clan_flake: Callable[..., Flake]) -> None:
|
||||
|
||||
inventory_store = InventoryStore(flake=flake)
|
||||
inventory = inventory_store.read()
|
||||
curr_tags = get_value_by_path(inventory, "machines.jon.tags", [])
|
||||
curr_tags = get_value_by_path(
|
||||
inventory, "machines.jon.tags", [], InventoryMachineTagsType
|
||||
)
|
||||
new_tags = ["managed1", "managed2"]
|
||||
set_value_by_path(inventory, "machines.jon.tags", [*curr_tags, *new_tags])
|
||||
inventory_store.write(inventory, message="Test add tags via API")
|
||||
|
||||
Reference in New Issue
Block a user