inventory/api: prepare list merging

This commit is contained in:
Johannes Kirschbauer
2024-12-05 17:43:35 +01:00
parent cb329900d9
commit 2d807c69e2
3 changed files with 74 additions and 14 deletions

View File

@@ -319,5 +319,4 @@ def from_dict(
msg = f"{data} is not a dict. Expected {t}" msg = f"{data} is not a dict. Expected {t}"
raise ClanError(msg) raise ClanError(msg)
return construct_dataclass(t, data, path) # type: ignore return construct_dataclass(t, data, path) # type: ignore
# breakpoint()
return construct_value(t, data, path) return construct_value(t, data, path)

View File

@@ -119,9 +119,26 @@ def flatten_data(data: dict, parent_key: str = "", separator: str = ".") -> dict
return flattened return flattened
def unmerge_lists(curr: list, prev: list) -> list:
"""
Unmerge the current list. Given a previous list.
Returns:
The other list.
"""
# Unmerge the lists
unmerged = []
for value in curr:
if value not in prev:
unmerged.append(value)
return unmerged
def determine_writeability( def determine_writeability(
data: dict, priorities: dict,
correlated: dict, defaults: dict,
persisted: dict,
parent_key: str = "", parent_key: str = "",
parent_prio: int | None = None, parent_prio: int | None = None,
results: dict | None = None, results: dict | None = None,
@@ -130,7 +147,7 @@ def determine_writeability(
if results is None: if results is None:
results = {"writeable": set({}), "non_writeable": set({})} results = {"writeable": set({}), "non_writeable": set({})}
for key, value in data.items(): for key, value in priorities.items():
if key == "__prio": if key == "__prio":
continue continue
@@ -149,6 +166,7 @@ def determine_writeability(
if isinstance(value, dict): if isinstance(value, dict):
determine_writeability( determine_writeability(
value, value,
defaults,
{}, # Children won't be writeable, so correlation doesn't matter here {}, # Children won't be writeable, so correlation doesn't matter here
full_key, full_key,
prio, # Pass the same priority down prio, # Pass the same priority down
@@ -159,13 +177,22 @@ def determine_writeability(
continue continue
# Check if the key is writeable otherwise # Check if the key is writeable otherwise
key_in_correlated = key in correlated key_in_correlated = key in persisted
if prio is None: if prio is None:
msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable." msg = f"Priority for key '{full_key}' is not defined. Cannot determine if it is writeable."
raise ClanError(msg) raise ClanError(msg)
has_children = any(k != "__prio" for k in value) is_mergeable = False
is_writeable = prio > 100 or key_in_correlated or has_children if prio == 100:
default = defaults.get(key)
if isinstance(default, dict):
is_mergeable = True
if isinstance(default, list):
is_mergeable = True
if key_in_correlated:
is_mergeable = True
is_writeable = prio > 100 or is_mergeable
# Append the result # Append the result
if is_writeable: if is_writeable:
@@ -177,7 +204,8 @@ def determine_writeability(
if isinstance(value, dict): if isinstance(value, dict):
determine_writeability( determine_writeability(
value, value,
correlated.get(key, {}), defaults.get(key, {}),
persisted.get(key, {}),
full_key, full_key,
prio, # Pass down current priority prio, # Pass down current priority
results, results,

View File

@@ -1,5 +1,5 @@
# Functions to test # Functions to test
from clan_cli.inventory import determine_writeability, patch from clan_cli.inventory import determine_writeability, patch, unmerge_lists
# --------- Patching tests --------- # --------- Patching tests ---------
@@ -49,8 +49,9 @@ def test_write_simple() -> None:
}, },
} }
default: dict = {"foo": {}}
data: dict = {} data: dict = {}
res = determine_writeability(prios, data) res = determine_writeability(prios, default, data)
assert res == {"writeable": {"foo", "foo.bar"}, "non_writeable": set({})} assert res == {"writeable": {"foo", "foo.bar"}, "non_writeable": set({})}
@@ -67,7 +68,7 @@ def test_write_inherited() -> None:
} }
data: dict = {} data: dict = {}
res = determine_writeability(prios, data) res = determine_writeability(prios, {"foo": {"bar": {}}}, data)
assert res == { assert res == {
"writeable": {"foo", "foo.bar", "foo.bar.baz"}, "writeable": {"foo", "foo.bar", "foo.bar.baz"},
"non_writeable": set(), "non_writeable": set(),
@@ -86,13 +87,34 @@ def test_non_write_inherited() -> None:
} }
data: dict = {} data: dict = {}
res = determine_writeability(prios, data) res = determine_writeability(prios, {}, data)
assert res == { assert res == {
"writeable": set(), "writeable": set(),
"non_writeable": {"foo", "foo.bar", "foo.bar.baz"}, "non_writeable": {"foo", "foo.bar", "foo.bar.baz"},
} }
def test_write_list() -> None:
prios = {
"foo": {
"__prio": 100,
},
}
data: dict = {}
default: dict = {
"foo": [
"a",
"b",
] # <- writeable: because lists are merged. Filtering out nix-values comes later
}
res = determine_writeability(prios, default, data)
assert res == {
"writeable": {"foo"},
"non_writeable": set(),
}
def test_write_because_written() -> None: def test_write_because_written() -> None:
prios = { prios = {
"foo": { "foo": {
@@ -107,7 +129,7 @@ def test_write_because_written() -> None:
# Given the following data. {} # Given the following data. {}
# Check that the non-writeable paths are correct. # Check that the non-writeable paths are correct.
res = determine_writeability(prios, {}) res = determine_writeability(prios, {"foo": {"bar": {}}}, {})
assert res == { assert res == {
"writeable": {"foo", "foo.bar"}, "writeable": {"foo", "foo.bar"},
"non_writeable": {"foo.bar.baz", "foo.bar.foobar"}, "non_writeable": {"foo.bar.baz", "foo.bar.foobar"},
@@ -120,8 +142,19 @@ def test_write_because_written() -> None:
} }
} }
} }
res = determine_writeability(prios, data) res = determine_writeability(prios, {}, data)
assert res == { assert res == {
"writeable": {"foo", "foo.bar", "foo.bar.baz"}, "writeable": {"foo", "foo.bar", "foo.bar.baz"},
"non_writeable": {"foo.bar.foobar"}, "non_writeable": {"foo.bar.foobar"},
} }
# --------- List unmerge tests ---------
def test_list_unmerge() -> None:
all_machines = ["machineA", "machineB"]
inventory = ["machineB"]
nix_machines = unmerge_lists(inventory, all_machines)
assert nix_machines == ["machineA"]