vars: refactor: use Machine objects instead of base_dir strings

Replace base_dir string parameters with Machine objects throughout the vars
module for better type safety and consistency.
This commit is contained in:
DavHau
2025-08-19 18:30:53 +07:00
parent bdab3e23af
commit 7b61a668e9
6 changed files with 25 additions and 38 deletions

View File

@@ -187,16 +187,9 @@ def test_generate_public_and_secret_vars(
"Update vars via generator my_shared_generator for machine my_machine"
in commit_message
)
assert get_machine_var(machine, "my_generator/my_value").printable_value == "public"
assert (
get_machine_var(
str(machine.flake.path), machine.name, "my_generator/my_value"
).printable_value
== "public"
)
assert (
get_machine_var(
str(machine.flake.path), machine.name, "my_shared_generator/my_shared_value"
).printable_value
get_machine_var(machine, "my_shared_generator/my_shared_value").printable_value
== "shared"
)
vars_text = stringify_all_vars(machine)
@@ -953,29 +946,21 @@ def test_invalidation(
monkeypatch.chdir(flake.path)
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
value1 = get_machine_var(
str(machine.flake.path), machine.name, "my_generator/my_value"
).printable_value
value1 = get_machine_var(machine, "my_generator/my_value").printable_value
# generate again and make sure nothing changes without the invalidation data being set
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
value1_new = get_machine_var(
str(machine.flake.path), machine.name, "my_generator/my_value"
).printable_value
value1_new = get_machine_var(machine, "my_generator/my_value").printable_value
assert value1 == value1_new
# set the invalidation data of the generator
my_generator["validation"] = 1
flake.refresh()
# generate again and make sure the value changes
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
value2 = get_machine_var(
str(machine.flake.path), machine.name, "my_generator/my_value"
).printable_value
value2 = get_machine_var(machine, "my_generator/my_value").printable_value
assert value1 != value2
# generate again without changing invalidation data -> value should not change
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
value2_new = get_machine_var(
str(machine.flake.path), machine.name, "my_generator/my_value"
).printable_value
value2_new = get_machine_var(machine, "my_generator/my_value").printable_value
assert value2 == value2_new

View File

@@ -9,6 +9,7 @@ from clan_cli.completions import (
)
from clan_lib.errors import ClanError
from clan_lib.flake import Flake, require_flake
from clan_lib.machines.machines import Machine
from .generator import Var
from .list import get_machine_vars
@@ -16,9 +17,9 @@ from .list import get_machine_vars
log = logging.getLogger(__name__)
def get_machine_var(base_dir: str, machine_name: str, var_id: str) -> Var:
log.debug(f"getting var: {var_id} from machine: {machine_name}")
vars_ = get_machine_vars(base_dir=base_dir, machine_name=machine_name)
def get_machine_var(machine: Machine, var_id: str) -> Var:
log.debug(f"getting var: {var_id} from machine: {machine.name}")
vars_ = get_machine_vars(machine)
results = []
for var in vars_:
if var.id == var_id:
@@ -44,7 +45,8 @@ def get_machine_var(base_dir: str, machine_name: str, var_id: str) -> Var:
def get_command(machine_name: str, var_id: str, flake: Flake) -> None:
var = get_machine_var(str(flake.path), machine_name, var_id)
machine = Machine(name=machine_name, flake=flake)
var = get_machine_var(machine, var_id)
if not var.exists:
msg = f"Var {var.id} has not been generated yet"
raise ClanError(msg)

View File

@@ -2,7 +2,7 @@ import argparse
import logging
from clan_cli.completions import add_dynamic_completer, complete_machines
from clan_lib.flake import Flake, require_flake
from clan_lib.flake import require_flake
from clan_lib.machines.machines import Machine
from clan_lib.vars.generate import get_generators
@@ -11,10 +11,9 @@ from .generator import Var
log = logging.getLogger(__name__)
def get_machine_vars(base_dir: str, machine_name: str) -> list[Var]:
def get_machine_vars(machine: Machine) -> list[Var]:
# TODO: We dont have machine level store / this granularity yet
# We should move the store definition to the flake, as there can be only one store per clan
machine = Machine(name=machine_name, flake=Flake(base_dir))
pub_store = machine.public_vars_store
sec_store = machine.secret_vars_store
@@ -37,7 +36,7 @@ def stringify_vars(_vars: list[Var]) -> str:
def stringify_all_vars(machine: Machine) -> str:
return stringify_vars(get_machine_vars(str(machine.flake), machine.name))
return stringify_vars(get_machine_vars(machine))
def list_command(args: argparse.Namespace) -> None:

View File

@@ -25,7 +25,7 @@ def set_var(machine: str | Machine, var: str | Var, value: bytes, flake: Flake)
else:
_machine = machine
if isinstance(var, str):
_var = get_machine_var(str(flake.path), _machine.name, var)
_var = get_machine_var(_machine, var)
else:
_var = var
path = _var.set(value)
@@ -39,7 +39,7 @@ def set_var(machine: str | Machine, var: str | Var, value: bytes, flake: Flake)
def set_via_stdin(machine_name: str, var_id: str, flake: Flake) -> None:
machine = Machine(name=machine_name, flake=flake)
var = get_machine_var(str(flake.path), machine_name, var_id)
var = get_machine_var(machine, var_id)
if sys.stdin.isatty():
new_value = ask(
var.id,

View File

@@ -32,11 +32,11 @@ class Peer:
_var: dict[str, str] = self._host["var"]
machine_name = _var["machine"]
generator = _var["generator"]
from clan_lib.machines.machines import Machine
machine = Machine(name=machine_name, flake=self.flake)
var = get_machine_var(
str(
self.flake
), # TODO we should really pass the flake instance here instead of a str representation
machine_name,
machine,
f"{generator}/{_var['file']}",
)
if not var.exists:

View File

@@ -2,6 +2,7 @@ from typing import Any
from unittest.mock import MagicMock, patch
from clan_lib.flake import Flake
from clan_lib.machines.machines import Machine
from clan_lib.network.network import Network, Peer, networks_from_flake
@@ -11,12 +12,12 @@ def test_networks_from_flake(mock_get_machine_var: MagicMock) -> None:
flake = MagicMock(spec=Flake)
# Mock the var decryption
def mock_var_side_effect(flake_path: str, machine: str, var_path: str) -> Any:
if machine == "machine1" and var_path == "wireguard/address":
def mock_var_side_effect(machine: Machine, var_path: str) -> Any:
if machine.name == "machine1" and var_path == "wireguard/address":
mock_var = MagicMock()
mock_var.value.decode.return_value = "192.168.1.10"
return mock_var
if machine == "machine2" and var_path == "wireguard/address":
if machine.name == "machine2" and var_path == "wireguard/address":
mock_var = MagicMock()
mock_var.value.decode.return_value = "192.168.1.11"
return mock_var