vars: allow re-encrypting secrets when recipient keys were added.

When the users of a secret change, when for example a new admin user is added, an error will be thrown when generating vars, prompting the user to pass --fix to re-encrypt the secrets
This commit is contained in:
DavHau
2024-11-13 18:18:25 +07:00
parent 83b7c6d9a2
commit 54b8f5904e
9 changed files with 221 additions and 34 deletions

View File

@@ -60,7 +60,9 @@ def flash_machine(
extra_args = [] extra_args = []
system_config_nix: dict[str, Any] = {} system_config_nix: dict[str, Any] = {}
generate_vars_for_machine(machine, generator_name=None, regenerate=False) generate_vars_for_machine(
machine, generator_name=None, regenerate=False, fix=False
)
generate_facts([machine], service=None, regenerate=False) generate_facts([machine], service=None, regenerate=False)
if system_config.wifi_settings: if system_config.wifi_settings:

View File

@@ -103,7 +103,7 @@ def update_group_keys(flake_dir: Path, group: str) -> list[Path]:
if (secret / "groups" / group).is_symlink(): if (secret / "groups" / group).is_symlink():
updated_paths += update_keys( updated_paths += update_keys(
secret, secret,
sorted(secrets.collect_keys_for_path(secret)), secrets.collect_keys_for_path(secret),
) )
return updated_paths return updated_paths

View File

@@ -1,7 +1,5 @@
import argparse import argparse
import functools
import getpass import getpass
import operator
import os import os
import shutil import shutil
import sys import sys
@@ -45,7 +43,7 @@ def update_secrets(
changed_files.extend( changed_files.extend(
update_keys( update_keys(
secret_path, secret_path,
sorted_keys(collect_keys_for_path(secret_path)), collect_keys_for_path(secret_path),
) )
) )
return changed_files return changed_files
@@ -147,7 +145,7 @@ def encrypt_secret(
) )
secret_path = secret_path / "secret" secret_path = secret_path / "secret"
encrypt_file(secret_path, value, sorted_keys(recipient_keys)) encrypt_file(secret_path, value, sorted(recipient_keys))
files_to_commit.append(secret_path) files_to_commit.append(secret_path)
if git_commit: if git_commit:
commit_files( commit_files(
@@ -231,7 +229,7 @@ def allow_member(
changed.extend( changed.extend(
update_keys( update_keys(
group_folder.parent, group_folder.parent,
sorted_keys(collect_keys_for_path(group_folder.parent)), collect_keys_for_path(group_folder.parent),
) )
) )
return changed return changed
@@ -257,12 +255,7 @@ def disallow_member(group_folder: Path, name: str) -> list[Path]:
if len(os.listdir(group_folder.parent)) == 0: if len(os.listdir(group_folder.parent)) == 0:
group_folder.parent.rmdir() group_folder.parent.rmdir()
return update_keys( return update_keys(target.parent.parent, collect_keys_for_path(group_folder.parent))
target.parent.parent, sorted_keys(collect_keys_for_path(group_folder.parent))
)
sorted_keys = functools.partial(sorted, key=operator.itemgetter(0))
def has_secret(secret_path: Path) -> bool: def has_secret(secret_path: Path) -> bool:

View File

@@ -4,7 +4,7 @@ import json
import os import os
import shutil import shutil
import subprocess import subprocess
from collections.abc import Iterator from collections.abc import Iterable, Iterator
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@@ -182,8 +182,9 @@ def sops_manifest(keys: list[tuple[str, KeyType]]) -> Iterator[Path]:
yield Path(manifest.name) yield Path(manifest.name)
def update_keys(secret_path: Path, keys: list[tuple[str, KeyType]]) -> list[Path]: def update_keys(secret_path: Path, keys: Iterable[tuple[str, KeyType]]) -> list[Path]:
with sops_manifest(keys) as manifest: keys_sorted = sorted(keys)
with sops_manifest(keys_sorted) as manifest:
secret_path = secret_path / "secret" secret_path = secret_path / "secret"
time_before = secret_path.stat().st_mtime time_before = secret_path.stat().st_mtime
cmd = nix_shell( cmd = nix_shell(

View File

@@ -8,7 +8,9 @@ from clan_cli.machines.machines import Machine
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def check_vars(machine: Machine, generator_name: None | str = None) -> bool: def vars_status(
machine: Machine, generator_name: None | str = None
) -> tuple[list[tuple[str, str]], list[tuple[str, str]], list[tuple[str, str]]]:
secret_vars_module = importlib.import_module(machine.secret_vars_module) secret_vars_module = importlib.import_module(machine.secret_vars_module)
secret_vars_store = secret_vars_module.SecretStore(machine=machine) secret_vars_store = secret_vars_module.SecretStore(machine=machine)
public_vars_module = importlib.import_module(machine.public_vars_module) public_vars_module = importlib.import_module(machine.public_vars_module)
@@ -16,6 +18,8 @@ def check_vars(machine: Machine, generator_name: None | str = None) -> bool:
missing_secret_vars = [] missing_secret_vars = []
missing_public_vars = [] missing_public_vars = []
# signals if a var needs to be updated (eg. needs re-encryption due to new users added)
outdated_secret_vars = []
if generator_name: if generator_name:
generators = [generator_name] generators = [generator_name]
else: else:
@@ -23,24 +27,39 @@ def check_vars(machine: Machine, generator_name: None | str = None) -> bool:
for generator_name in generators: for generator_name in generators:
shared = machine.vars_generators[generator_name]["share"] shared = machine.vars_generators[generator_name]["share"]
for name, file in machine.vars_generators[generator_name]["files"].items(): for name, file in machine.vars_generators[generator_name]["files"].items():
if file["secret"] and not secret_vars_store.exists( if file["secret"]:
generator_name, name, shared=shared if not secret_vars_store.exists(generator_name, name, shared=shared):
): log.info(
f"Secret var '{name}' for service '{generator_name}' in machine {machine.name} is missing."
)
missing_secret_vars.append((generator_name, name))
else:
needs_update, msg = secret_vars_store.needs_fix(
generator_name, name, shared=shared
)
if needs_update:
log.info(
f"Secret var '{name}' for service '{generator_name}' in machine {machine.name} needs update: {msg}"
)
outdated_secret_vars.append((generator_name, name))
elif not public_vars_store.exists(generator_name, name, shared=shared):
log.info( log.info(
f"Secret fact '{name}' for service '{generator_name}' in machine {machine.name} is missing." f"Public var '{name}' for service '{generator_name}' in machine {machine.name} is missing."
)
missing_secret_vars.append((generator_name, name))
if not file["secret"] and not public_vars_store.exists(
generator_name, name, shared=shared
):
log.info(
f"Public fact '{name}' for service '{generator_name}' in machine {machine.name} is missing."
) )
missing_public_vars.append((generator_name, name)) missing_public_vars.append((generator_name, name))
log.debug(f"missing_secret_vars: {missing_secret_vars}") log.debug(f"missing_secret_vars: {missing_secret_vars}")
log.debug(f"missing_public_vars: {missing_public_vars}") log.debug(f"missing_public_vars: {missing_public_vars}")
return not (missing_secret_vars or missing_public_vars) log.debug(f"outdated_secret_vars: {outdated_secret_vars}")
return missing_secret_vars, missing_public_vars, outdated_secret_vars
def check_vars(machine: Machine, generator_name: None | str = None) -> bool:
missing_secret_vars, missing_public_vars, outdated_secret_vars = vars_status(
machine, generator_name=generator_name
)
return not (missing_secret_vars or missing_public_vars or outdated_secret_vars)
def check_command(args: argparse.Namespace) -> None: def check_command(args: argparse.Namespace) -> None:

View File

@@ -318,11 +318,49 @@ def _check_can_migrate(
) )
def ensure_consistent_state(
machine: Machine,
generator_name: str | None,
fix: bool,
) -> None:
"""
Apply local updates to secrets like re-encrypting with missing keys
when new users were added.
"""
if generator_name is None:
generators = list(machine.vars_generators.keys())
else:
generators = [generator_name]
outdated = []
for generator_name in generators:
for name, file in machine.vars_generators[generator_name]["files"].items():
shared = machine.vars_generators[generator_name]["share"]
if file["secret"] and machine.secret_vars_store.exists(
generator_name, name
):
needs_update, msg = machine.secret_vars_store.needs_fix(
generator_name, name, shared=shared
)
if needs_update:
outdated.append((generator_name, name, msg))
if not fix and outdated:
msg = (
"The local state of some secret vars is inconsistent and needs to be updated.\n"
"Rerun 'clan vars generate' passing '--fix' to apply the necessary changes."
"Problems to fix:\n"
"\n".join(o[2] for o in outdated if o[2])
)
raise ClanError(msg)
def generate_vars_for_machine( def generate_vars_for_machine(
machine: Machine, machine: Machine,
generator_name: str | None, generator_name: str | None,
regenerate: bool, regenerate: bool,
fix: bool,
) -> bool: ) -> bool:
ensure_consistent_state(machine, generator_name, fix)
closure = get_closure(machine, generator_name, regenerate) closure = get_closure(machine, generator_name, regenerate)
if len(closure) == 0: if len(closure) == 0:
return False return False
@@ -347,13 +385,14 @@ def generate_vars(
machines: list[Machine], machines: list[Machine],
generator_name: str | None, generator_name: str | None,
regenerate: bool, regenerate: bool,
fix: bool = False,
) -> bool: ) -> bool:
was_regenerated = False was_regenerated = False
for machine in machines: for machine in machines:
errors = [] errors = []
try: try:
was_regenerated |= generate_vars_for_machine( was_regenerated |= generate_vars_for_machine(
machine, generator_name, regenerate machine, generator_name, regenerate, fix
) )
machine.flush_caches() machine.flush_caches()
except Exception as exc: except Exception as exc:
@@ -376,7 +415,7 @@ def generate_command(args: argparse.Namespace) -> None:
machines = get_all_machines(args.flake, args.option) machines = get_all_machines(args.flake, args.option)
else: else:
machines = get_selected_machines(args.flake, args.option, args.machines) machines = get_selected_machines(args.flake, args.option, args.machines)
generate_vars(machines, args.service, args.regenerate) generate_vars(machines, args.service, args.regenerate, args.fix)
def register_generate_parser(parser: argparse.ArgumentParser) -> None: def register_generate_parser(parser: argparse.ArgumentParser) -> None:
@@ -403,4 +442,12 @@ def register_generate_parser(parser: argparse.ArgumentParser) -> None:
help="whether to regenerate facts for the specified machine", help="whether to regenerate facts for the specified machine",
default=None, default=None,
) )
parser.add_argument(
"--fix",
action=argparse.BooleanOptionalAction,
help="whether to fix local state inconsistencies, for example if a secret is not encrypted with the correct keys",
default=False,
)
parser.set_defaults(func=generate_command) parser.set_defaults(func=generate_command)

View File

@@ -12,6 +12,27 @@ class SecretStoreBase(StoreBase):
def needs_upload(self) -> bool: def needs_upload(self) -> bool:
return True return True
def needs_fix(
self,
generator_name: str,
name: str,
shared: bool,
) -> tuple[bool, str | None]:
"""
Check if local state needs updating, eg. secret needs to be re-encrypted with new keys
"""
return False, None
def fix(
self,
generator_name: str,
name: str,
shared: bool,
) -> None:
"""
Update local state, eg make sure secret is encrypted with correct keys
"""
@abstractmethod @abstractmethod
def upload(self, output_dir: Path) -> None: def upload(self, output_dir: Path) -> None:
pass pass

View File

@@ -1,13 +1,23 @@
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import override
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
from clan_cli.secrets.folders import sops_machines_folder, sops_secrets_folder from clan_cli.secrets.folders import (
sops_machines_folder,
sops_secrets_folder,
sops_users_folder,
)
from clan_cli.secrets.machines import add_machine, add_secret, has_machine from clan_cli.secrets.machines import add_machine, add_secret, has_machine
from clan_cli.secrets.secrets import decrypt_secret, encrypt_secret, has_secret from clan_cli.secrets.secrets import (
from clan_cli.secrets.sops import generate_private_key collect_keys_for_path,
decrypt_secret,
encrypt_secret,
has_secret,
)
from clan_cli.secrets.sops import KeyType, generate_private_key
from . import SecretStoreBase from . import SecretStoreBase
@@ -56,6 +66,18 @@ class SecretStore(SecretStoreBase):
def store_name(self) -> str: def store_name(self) -> str:
return "sops" return "sops"
def user_has_access(
self, user: str, generator_name: str, secret_name: str, shared: bool
) -> bool:
secret_path = self.secret_path(generator_name, secret_name, shared=shared)
secret = json.loads((secret_path / "secret").read_text())
recipients = [r["recipient"] for r in (secret["sops"].get("age") or [])]
users_folder_path = sops_users_folder(self.machine.flake_dir)
user_pubkey = json.loads((users_folder_path / user / "key.json").read_text())[
"publickey"
]
return user_pubkey in recipients
def machine_has_access( def machine_has_access(
self, generator_name: str, secret_name: str, shared: bool self, generator_name: str, secret_name: str, shared: bool
) -> bool: ) -> bool:
@@ -131,3 +153,53 @@ class SecretStore(SecretStoreBase):
if not shared: if not shared:
return True return True
return self.machine_has_access(generator_name, name, shared) return self.machine_has_access(generator_name, name, shared)
def collect_keys_for_secret(self, path: Path) -> set[tuple[str, KeyType]]:
from clan_cli.secrets.secrets import (
collect_keys_for_path,
collect_keys_for_type,
)
keys = collect_keys_for_path(path)
for group in self.machine.deployment["sops"]["defaultGroups"]:
keys.update(
collect_keys_for_type(
self.machine.flake_dir / "sops" / "groups" / group / "machines"
)
)
keys.update(
collect_keys_for_type(
self.machine.flake_dir / "sops" / "groups" / group / "users"
)
)
return keys
@override
def needs_fix(
self, generator_name: str, name: str, shared: bool
) -> tuple[bool, str | None]:
secret_path = self.secret_path(generator_name, name, shared)
recipients_ = json.loads((secret_path / "secret").read_text())["sops"]["age"]
current_recipients = {r["recipient"] for r in recipients_}
wanted_recipients = {
key[0] for key in self.collect_keys_for_secret(secret_path)
}
needs_update = current_recipients != wanted_recipients
recipients_to_add = wanted_recipients - current_recipients
var_id = f"{generator_name}/{name}"
msg = (
f"One or more recipient keys were added to secret{' shared' if shared else ''} var '{var_id}', but it was never re-encrypted. "
f"This could have been a malicious actor trying to add their keys, please investigate. "
f"Added keys: {', '.join(recipients_to_add)}"
)
return needs_update, msg
@override
def fix(self, generator_name: str, name: str, shared: bool) -> None:
from clan_cli.secrets.secrets import update_keys
secret_path = self.secret_path(generator_name, name, shared)
update_keys(
secret_path,
collect_keys_for_path(secret_path),
)

View File

@@ -191,6 +191,31 @@ def test_generate_secret_var_sops_with_default_group(
) )
assert sops_store.exists("my_generator", "my_secret") assert sops_store.exists("my_generator", "my_secret")
assert sops_store.get("my_generator", "my_secret").decode() == "hello\n" assert sops_store.get("my_generator", "my_secret").decode() == "hello\n"
# add another user and check if secret gets re-encrypted
from clan_cli.secrets.sops import generate_private_key
_, pubkey_uschi = generate_private_key()
cli.run(
[
"secrets",
"users",
"add",
"--flake",
str(flake.path),
"uschi",
pubkey_uschi,
]
)
cli.run(["secrets", "groups", "add-user", "my_group", "uschi"])
with pytest.raises(ClanError):
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
# apply fix
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine", "--fix"])
# check if new user can access the secret
monkeypatch.setenv("USER", "uschi")
assert sops_store.user_has_access(
"uschi", "my_generator", "my_secret", shared=False
)
@pytest.mark.impure @pytest.mark.impure
@@ -746,6 +771,7 @@ def test_stdout_of_generate(
Machine(name="my_machine", flake=FlakeId(str(flake.path))), Machine(name="my_machine", flake=FlakeId(str(flake.path))),
"my_generator", "my_generator",
regenerate=False, regenerate=False,
fix=False,
) )
assert "Updated var my_generator/my_value" in output.out assert "Updated var my_generator/my_value" in output.out
@@ -757,6 +783,7 @@ def test_stdout_of_generate(
Machine(name="my_machine", flake=FlakeId(str(flake.path))), Machine(name="my_machine", flake=FlakeId(str(flake.path))),
"my_generator", "my_generator",
regenerate=True, regenerate=True,
fix=False,
) )
assert "Updated var my_generator/my_value" in output.out assert "Updated var my_generator/my_value" in output.out
assert "old: world" in output.out assert "old: world" in output.out
@@ -767,6 +794,7 @@ def test_stdout_of_generate(
Machine(name="my_machine", flake=FlakeId(str(flake.path))), Machine(name="my_machine", flake=FlakeId(str(flake.path))),
"my_generator", "my_generator",
regenerate=True, regenerate=True,
fix=False,
) )
assert "Updated" not in output.out assert "Updated" not in output.out
assert "hello" in output.out assert "hello" in output.out
@@ -775,6 +803,7 @@ def test_stdout_of_generate(
Machine(name="my_machine", flake=FlakeId(str(flake.path))), Machine(name="my_machine", flake=FlakeId(str(flake.path))),
"my_secret_generator", "my_secret_generator",
regenerate=False, regenerate=False,
fix=False,
) )
assert "Updated secret var my_secret_generator/my_secret" in output.out assert "Updated secret var my_secret_generator/my_secret" in output.out
assert "hello" not in output.out assert "hello" not in output.out
@@ -789,6 +818,7 @@ def test_stdout_of_generate(
Machine(name="my_machine", flake=FlakeId(str(flake.path))), Machine(name="my_machine", flake=FlakeId(str(flake.path))),
"my_secret_generator", "my_secret_generator",
regenerate=True, regenerate=True,
fix=False,
) )
assert "Updated secret var my_secret_generator/my_secret" in output.out assert "Updated secret var my_secret_generator/my_secret" in output.out
assert "world" not in output.out assert "world" not in output.out
@@ -891,6 +921,7 @@ def test_fails_when_files_are_left_from_other_backend(
Machine(name="my_machine", flake=FlakeId(str(flake.path))), Machine(name="my_machine", flake=FlakeId(str(flake.path))),
generator, generator,
regenerate=False, regenerate=False,
fix=False,
) )
my_secret_generator["files"]["my_secret"]["secret"] = False my_secret_generator["files"]["my_secret"]["secret"] = False
my_value_generator["files"]["my_value"]["secret"] = True my_value_generator["files"]["my_value"]["secret"] = True
@@ -902,6 +933,7 @@ def test_fails_when_files_are_left_from_other_backend(
Machine(name="my_machine", flake=FlakeId(str(flake.path))), Machine(name="my_machine", flake=FlakeId(str(flake.path))),
generator, generator,
regenerate=False, regenerate=False,
fix=False,
) )