ruff: apply automatic fixes

This commit is contained in:
Jörg Thalheim
2025-08-20 13:52:45 +02:00
parent 798d445f3e
commit ea2d6aab65
217 changed files with 2283 additions and 1739 deletions

View File

@@ -22,12 +22,16 @@ def create_command(args: argparse.Namespace) -> None:
def register_create_parser(parser: argparse.ArgumentParser) -> None:
machines_parser = parser.add_argument(
"machine", type=str, help="machine in the flake to create backups of"
"machine",
type=str,
help="machine in the flake to create backups of",
)
add_dynamic_completer(machines_parser, complete_machines)
provider_action = parser.add_argument(
"--provider", type=str, help="backup provider to use"
"--provider",
type=str,
help="backup provider to use",
)
add_dynamic_completer(provider_action, complete_backup_providers_for_machine)
parser.set_defaults(func=create_command)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_create_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -21,11 +21,15 @@ def list_command(args: argparse.Namespace) -> None:
def register_list_parser(parser: argparse.ArgumentParser) -> None:
machines_parser = parser.add_argument(
"machine", type=str, help="machine in the flake to show backups of"
"machine",
type=str,
help="machine in the flake to show backups of",
)
add_dynamic_completer(machines_parser, complete_machines)
provider_action = parser.add_argument(
"--provider", type=str, help="backup provider to filter by"
"--provider",
type=str,
help="backup provider to filter by",
)
add_dynamic_completer(provider_action, complete_backup_providers_for_machine)
parser.set_defaults(func=list_command)

View File

@@ -24,11 +24,15 @@ def restore_command(args: argparse.Namespace) -> None:
def register_restore_parser(parser: argparse.ArgumentParser) -> None:
machine_action = parser.add_argument(
"machine", type=str, help="machine in the flake to create backups of"
"machine",
type=str,
help="machine in the flake to create backups of",
)
add_dynamic_completer(machine_action, complete_machines)
provider_action = parser.add_argument(
"provider", type=str, help="backup provider to use"
"provider",
type=str,
help="backup provider to use",
)
add_dynamic_completer(provider_action, complete_backup_providers_for_machine)
parser.add_argument("name", type=str, help="Name of the backup to restore")

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_restore_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -67,7 +67,7 @@ def register_create_parser(parser: argparse.ArgumentParser) -> None:
setup_git=not args.no_git,
src_flake=args.flake,
update_clan=not args.no_update,
)
),
)
create_secrets_user_auto(
flake_dir=Path(args.name).resolve(),

View File

@@ -74,8 +74,8 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
# Get the Clan name
cmd = nix_eval(
[
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.name'
]
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.name',
],
)
res = run_cmd(cmd)
clan_name = res.strip('"')
@@ -83,8 +83,8 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
# Get the clan icon path
cmd = nix_eval(
[
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon'
]
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon',
],
)
res = run_cmd(cmd)
@@ -96,7 +96,7 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
cmd = nix_build(
[
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon'
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon',
],
machine_gcroot(flake_url=str(flake_url)) / "icon",
)
@@ -129,7 +129,8 @@ def inspect_command(args: argparse.Namespace) -> None:
flake=args.flake or Flake(str(Path.cwd())),
)
res = inspect_flake(
flake_url=str(inspect_options.flake), machine_name=inspect_options.machine
flake_url=str(inspect_options.flake),
machine_name=inspect_options.machine,
)
print("Clan name:", res.clan_name)
print("Icon:", res.icon)

View File

@@ -10,7 +10,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core
def test_clan_show(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput
test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(["show", "--flake", str(test_flake_with_core.path)])
@@ -20,7 +21,9 @@ def test_clan_show(
def test_clan_show_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capture_output: CaptureOutput
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
capture_output: CaptureOutput,
) -> None:
monkeypatch.chdir(tmp_path)
@@ -28,8 +31,8 @@ def test_clan_show_no_flake(
cli.run(["show"])
assert "No clan flake found in the current directory or its parents" in str(
exc_info.value
exc_info.value,
)
assert "Use the --flake flag to specify a clan flake path or URL" in str(
exc_info.value
exc_info.value,
)

View File

@@ -52,8 +52,7 @@ def create_flake_from_args(args: argparse.Namespace) -> Flake:
def add_common_flags(parser: argparse.ArgumentParser) -> None:
def argument_exists(parser: argparse.ArgumentParser, arg: str) -> bool:
"""
Check if an argparse argument already exists.
"""Check if an argparse argument already exists.
This is needed because the aliases subcommand doesn't *really*
create an alias - it duplicates the actual parser in the tree
making duplication inevitable while naively traversing.
@@ -410,7 +409,9 @@ For more detailed information, visit: {help_hyperlink("deploy", "https://docs.cl
machines.register_parser(parser_machine)
parser_vms = subparsers.add_parser(
"vms", help="Manage virtual machines", description="Manage virtual machines"
"vms",
help="Manage virtual machines",
description="Manage virtual machines",
)
vms.register_parser(parser_vms)

View File

@@ -38,11 +38,11 @@ def clan_dir(flake: str | None) -> str | None:
def complete_machines(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for machine names configured in the clan.
"""
"""Provides completion functionality for machine names configured in the clan."""
machines: list[str] = []
def run_cmd() -> None:
@@ -72,11 +72,11 @@ def complete_machines(
def complete_services_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for machine facts generation services.
"""
"""Provides completion functionality for machine facts generation services."""
services: list[str] = []
# TODO: consolidate, if multiple machines are used
machines: list[str] = parsed_args.machines
@@ -98,7 +98,7 @@ def complete_services_for_machine(
"builtins.attrNames",
],
),
).stdout.strip()
).stdout.strip(),
)
services.extend(services_result)
@@ -117,11 +117,11 @@ def complete_services_for_machine(
def complete_backup_providers_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for machine backup providers.
"""
"""Provides completion functionality for machine backup providers."""
providers: list[str] = []
machine: str = parsed_args.machine
@@ -142,7 +142,7 @@ def complete_backup_providers_for_machine(
"builtins.attrNames",
],
),
).stdout.strip()
).stdout.strip(),
)
providers.extend(providers_result)
@@ -161,11 +161,11 @@ def complete_backup_providers_for_machine(
def complete_state_services_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for machine state providers.
"""
"""Provides completion functionality for machine state providers."""
providers: list[str] = []
machine: str = parsed_args.machine
@@ -186,7 +186,7 @@ def complete_state_services_for_machine(
"builtins.attrNames",
],
),
).stdout.strip()
).stdout.strip(),
)
providers.extend(providers_result)
@@ -205,11 +205,11 @@ def complete_state_services_for_machine(
def complete_secrets(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for clan secrets
"""
"""Provides completion functionality for clan secrets"""
from clan_lib.flake.flake import Flake
from .secrets.secrets import list_secrets
@@ -228,11 +228,11 @@ def complete_secrets(
def complete_users(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for clan users
"""
"""Provides completion functionality for clan users"""
from pathlib import Path
from .secrets.users import list_users
@@ -251,11 +251,11 @@ def complete_users(
def complete_groups(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for clan groups
"""
"""Provides completion functionality for clan groups"""
from pathlib import Path
from .secrets.groups import list_groups
@@ -275,12 +275,11 @@ def complete_groups(
def complete_templates_disko(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for disko templates
"""
"""Provides completion functionality for disko templates"""
from clan_lib.templates import list_templates
flake = (
@@ -300,12 +299,11 @@ def complete_templates_disko(
def complete_templates_clan(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for clan templates
"""
"""Provides completion functionality for clan templates"""
from clan_lib.templates import list_templates
flake = (
@@ -325,10 +323,11 @@ def complete_templates_clan(
def complete_vars_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for variable names for a specific machine.
"""Provides completion functionality for variable names for a specific machine.
Only completes vars that already exist in the vars directory on disk.
This is fast as it only scans the filesystem without any evaluation.
"""
@@ -368,11 +367,11 @@ def complete_vars_for_machine(
def complete_target_host(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for target_host for a specific machine
"""
"""Provides completion functionality for target_host for a specific machine"""
target_hosts: list[str] = []
machine: str = parsed_args.machine
@@ -391,7 +390,7 @@ def complete_target_host(
f"{flake}#nixosConfigurations.{machine}.config.clan.core.networking.targetHost",
],
),
).stdout.strip()
).stdout.strip(),
)
target_hosts.append(target_host_result)
@@ -410,11 +409,11 @@ def complete_target_host(
def complete_tags(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any
prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]:
"""
Provides completion functionality for tags inside the inventory
"""
"""Provides completion functionality for tags inside the inventory"""
tags: list[str] = []
threads = []
@@ -483,8 +482,7 @@ def add_dynamic_completer(
action: argparse.Action,
completer: Callable[..., Iterable[str]],
) -> None:
"""
Add a completion function to an argparse action, this will only be added,
"""Add a completion function to an argparse action, this will only be added,
if the argcomplete module is loaded.
"""
if argcomplete:

View File

@@ -21,14 +21,14 @@ def check_secrets(machine: Machine, service: None | str = None) -> bool:
secret_name = secret_fact["name"]
if not machine.secret_facts_store.exists(service, secret_name):
machine.info(
f"Secret fact '{secret_fact}' for service '{service}' is missing."
f"Secret fact '{secret_fact}' for service '{service}' is missing.",
)
missing_secret_facts.append((service, secret_name))
for public_fact in machine.facts_data[service]["public"]:
if not machine.public_facts_store.exists(service, public_fact):
machine.info(
f"Public fact '{public_fact}' for service '{service}' is missing."
f"Public fact '{public_fact}' for service '{service}' is missing.",
)
missing_public_facts.append((service, public_fact))

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_check_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -29,9 +29,7 @@ log = logging.getLogger(__name__)
def read_multiline_input(prompt: str = "Finish with Ctrl-D") -> str:
"""
Read multi-line input from stdin.
"""
"""Read multi-line input from stdin."""
print(prompt, flush=True)
proc = run(["cat"], RunOpts(check=False))
log.info("Input received. Processing...")
@@ -63,7 +61,7 @@ def bubblewrap_cmd(generator: str, facts_dir: Path, secrets_dir: Path) -> list[s
"--uid", "1000",
"--gid", "1000",
"--",
"bash", "-c", generator
"bash", "-c", generator,
],
)
# fmt: on
@@ -102,7 +100,8 @@ def generate_service_facts(
generator = machine.facts_data[service]["generator"]["finalScript"]
if machine.facts_data[service]["generator"]["prompt"]:
prompt_value = prompt(
service, machine.facts_data[service]["generator"]["prompt"]
service,
machine.facts_data[service]["generator"]["prompt"],
)
env["prompt_value"] = prompt_value
from clan_lib import bwrap
@@ -126,7 +125,10 @@ def generate_service_facts(
msg += generator
raise ClanError(msg)
secret_path = secret_facts_store.set(
service, secret_name, secret_file.read_bytes(), groups
service,
secret_name,
secret_file.read_bytes(),
groups,
)
if secret_path:
files_to_commit.append(secret_path)
@@ -206,7 +208,11 @@ def generate_facts(
errors = 0
try:
was_regenerated |= _generate_facts_for_machine(
machine, service, regenerate, tmpdir, prompt
machine,
service,
regenerate,
tmpdir,
prompt,
)
except (OSError, ClanError) as e:
machine.error(f"Failed to generate facts: {e}")
@@ -231,7 +237,7 @@ def generate_command(args: argparse.Namespace) -> None:
filter(
lambda m: m.name in args.machines,
machines,
)
),
)
generate_facts(machines, args.service, args.regenerate)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_generate_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
import clan_lib.machines.machines as machines
from clan_lib.machines import machines
class FactStoreBase(ABC):

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
import clan_lib.machines.machines as machines
from clan_lib.machines import machines
from clan_lib.ssh.host import Host
@@ -14,7 +14,11 @@ class SecretStoreBase(ABC):
@abstractmethod
def set(
self, service: str, name: str, value: bytes, groups: list[str]
self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None:
pass

View File

@@ -16,7 +16,11 @@ class SecretStore(SecretStoreBase):
self.machine = machine
def set(
self, service: str, name: str, value: bytes, groups: list[str]
self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None:
subprocess.run(
nix_shell(
@@ -40,14 +44,16 @@ class SecretStore(SecretStoreBase):
def exists(self, service: str, name: str) -> bool:
password_store = os.environ.get(
"PASSWORD_STORE_DIR", f"{os.environ['HOME']}/.password-store"
"PASSWORD_STORE_DIR",
f"{os.environ['HOME']}/.password-store",
)
secret_path = Path(password_store) / f"machines/{self.machine.name}/{name}.gpg"
return secret_path.exists()
def generate_hash(self) -> bytes:
password_store = os.environ.get(
"PASSWORD_STORE_DIR", f"{os.environ['HOME']}/.password-store"
"PASSWORD_STORE_DIR",
f"{os.environ['HOME']}/.password-store",
)
hashes = []
hashes.append(
@@ -66,7 +72,7 @@ class SecretStore(SecretStoreBase):
),
stdout=subprocess.PIPE,
check=False,
).stdout.strip()
).stdout.strip(),
)
for symlink in Path(password_store).glob(f"machines/{self.machine.name}/**/*"):
if symlink.is_symlink():
@@ -86,7 +92,7 @@ class SecretStore(SecretStoreBase):
),
stdout=subprocess.PIPE,
check=False,
).stdout.strip()
).stdout.strip(),
)
# we sort the hashes to make sure that the order is always the same

View File

@@ -37,7 +37,11 @@ class SecretStore(SecretStoreBase):
add_machine(self.machine.flake_dir, self.machine.name, pub_key, False)
def set(
self, service: str, name: str, value: bytes, groups: list[str]
self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None:
path = (
sops_secrets_folder(self.machine.flake_dir) / f"{self.machine.name}-{name}"

View File

@@ -15,7 +15,11 @@ class SecretStore(SecretStoreBase):
self.dir.mkdir(parents=True, exist_ok=True)
def set(
self, service: str, name: str, value: bytes, groups: list[str]
self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None:
secret_file = self.dir / service / name
secret_file.parent.mkdir(parents=True, exist_ok=True)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_upload_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -21,6 +21,7 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
register_flash_write_parser(write_parser)
list_parser = subparser.add_parser(
"list", help="List possible keymaps or languages"
"list",
help="List possible keymaps or languages",
)
register_flash_list_parser(list_parser)

View File

@@ -121,7 +121,7 @@ def register_flash_write_parser(parser: argparse.ArgumentParser) -> None:
Format will format the disk before installing.
Mount will mount the disk before installing.
Mount is useful for updating an existing system without losing data.
"""
""",
)
parser.add_argument(
"--mode",
@@ -166,7 +166,7 @@ def register_flash_write_parser(parser: argparse.ArgumentParser) -> None:
Write EFI boot entries to the NVRAM of the system for the installed system.
Specify this option if you plan to boot from this disk on the current machine,
but not if you plan to move the disk to another machine.
"""
""",
).strip(),
default=False,
action="store_true",

View File

@@ -8,7 +8,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core
def test_flash_list_languages(
temporary_home: Path, capture_output: CaptureOutput
temporary_home: Path,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(["flash", "list", "languages"])
@@ -20,7 +21,8 @@ def test_flash_list_languages(
@pytest.mark.with_core
def test_flash_list_keymaps(
temporary_home: Path, capture_output: CaptureOutput
temporary_home: Path,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(["flash", "list", "keymaps"])

View File

@@ -1,7 +1,6 @@
# Implementation of OSC8
def hyperlink(text: str, url: str) -> str:
"""
Generate OSC8 escape sequence for hyperlinks.
"""Generate OSC8 escape sequence for hyperlinks.
Args:
url (str): The URL to link to.
@@ -9,15 +8,14 @@ def hyperlink(text: str, url: str) -> str:
Returns:
str: The formatted string with an embedded hyperlink.
"""
esc = "\033"
return f"{esc}]8;;{url}{esc}\\{text}{esc}]8;;{esc}\\"
def hyperlink_same_text_and_url(url: str) -> str:
"""
Keep the description and the link the same to support legacy terminals.
"""
"""Keep the description and the link the same to support legacy terminals."""
return hyperlink(url, url)
@@ -34,9 +32,7 @@ def help_hyperlink(description: str, url: str) -> str:
def docs_hyperlink(description: str, url: str) -> str:
"""
Returns a markdown hyperlink
"""
"""Returns a markdown hyperlink"""
url = url.replace("https://docs.clan.lol", "../..")
url = url.replace("index.html", "index")
url += ".md"

View File

@@ -32,8 +32,7 @@ def create_machine(
opts: CreateOptions,
commit: bool = True,
) -> None:
"""
Create a new machine in the clan directory.
"""Create a new machine in the clan directory.
This function will create a new machine based on a template.
@@ -41,7 +40,6 @@ def create_machine(
:param commit: Whether to commit the changes to the git repository.
:param _persist: Temporary workaround for 'morph'. Whether to persist the changes to the inventory store.
"""
if not opts.clan_dir.is_local:
msg = f"Clan {opts.clan_dir} is not a local clan."
description = "Import machine only works on local clans"

View File

@@ -33,13 +33,15 @@ def update_hardware_config_command(args: argparse.Namespace) -> None:
if args.target_host:
target_host = Remote.from_ssh_uri(
machine_name=machine.name, address=args.target_host
machine_name=machine.name,
address=args.target_host,
)
else:
target_host = machine.target_host()
target_host = target_host.override(
host_key_check=args.host_key_check, private_key=args.identity_file
host_key_check=args.host_key_check,
private_key=args.identity_file,
)
run_machine_hardware_info(opts, target_host)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_create_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -34,7 +34,8 @@ def install_command(args: argparse.Namespace) -> None:
if args.target_host:
# TODO add network support here with either --network or some url magic
remote = Remote.from_ssh_uri(
machine_name=args.machine, address=args.target_host
machine_name=args.machine,
address=args.target_host,
)
elif args.png:
data = read_qr_image(Path(args.png))
@@ -73,7 +74,7 @@ def install_command(args: argparse.Namespace) -> None:
if ask == "n" or ask == "":
return None
print(
f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no."
f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no.",
)
if args.identity_file:

View File

@@ -13,7 +13,8 @@ def list_command(args: argparse.Namespace) -> None:
flake = require_flake(args.flake)
for name in list_machines(
flake, opts=ListOptions(filter=MachineFilter(tags=args.tags))
flake,
opts=ListOptions(filter=MachineFilter(tags=args.tags)),
):
print(name)

View File

@@ -43,7 +43,7 @@ def list_basic(
description = "Backup server";
};
};
}"""
}""",
},
],
indirect=True,
@@ -62,7 +62,7 @@ def list_with_tags_single_tag(
str(test_flake_with_core.path),
"--tags",
"production",
]
],
)
assert "web-server" in output.out
@@ -94,7 +94,7 @@ def list_with_tags_single_tag(
description = "Backup server";
};
};
}"""
}""",
},
],
indirect=True,
@@ -114,7 +114,7 @@ def list_with_tags_multiple_tags_intersection(
"--tags",
"web",
"production",
]
],
)
# Should only include machines that have BOTH tags (intersection)
@@ -139,7 +139,7 @@ def test_machines_list_with_tags_no_matches(
str(test_flake_with_core.path),
"--tags",
"nonexistent",
]
],
)
assert output.out.strip() == ""
@@ -162,7 +162,7 @@ def test_machines_list_with_tags_no_matches(
};
server4 = { };
};
}"""
}""",
},
],
indirect=True,
@@ -180,7 +180,7 @@ def list_with_tags_various_scenarios(
str(test_flake_with_core.path),
"--tags",
"web",
]
],
)
assert "server1" in output.out
@@ -197,7 +197,7 @@ def list_with_tags_various_scenarios(
str(test_flake_with_core.path),
"--tags",
"database",
]
],
)
assert "server2" in output.out
@@ -216,7 +216,7 @@ def list_with_tags_various_scenarios(
"--tags",
"web",
"database",
]
],
)
assert "server3" in output.out
@@ -239,7 +239,7 @@ def created_machine_and_tags(
"--tags",
"test",
"server",
]
],
)
with capture_output as output:
@@ -258,7 +258,7 @@ def created_machine_and_tags(
str(test_flake_with_core.path),
"--tags",
"test",
]
],
)
assert "test-machine" in output.out
@@ -274,7 +274,7 @@ def created_machine_and_tags(
str(test_flake_with_core.path),
"--tags",
"server",
]
],
)
assert "test-machine" in output.out
@@ -291,7 +291,7 @@ def created_machine_and_tags(
"--tags",
"test",
"server",
]
],
)
assert "test-machine" in output.out
@@ -310,7 +310,7 @@ def created_machine_and_tags(
};
machine-without-tags = { };
};
}"""
}""",
},
],
indirect=True,
@@ -334,7 +334,7 @@ def list_mixed_tagged_untagged(
str(test_flake_with_core.path),
"--tags",
"tag1",
]
],
)
assert "machine-with-tags" in output.out
@@ -349,7 +349,7 @@ def list_mixed_tagged_untagged(
str(test_flake_with_core.path),
"--tags",
"nonexistent",
]
],
)
assert "machine-with-tags" not in output.out
@@ -358,7 +358,8 @@ def list_mixed_tagged_untagged(
def test_machines_list_require_flake_error(
temporary_home: Path, monkeypatch: pytest.MonkeyPatch
temporary_home: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test that machines list command fails when flake is required but not provided."""
monkeypatch.chdir(temporary_home)

View File

@@ -15,7 +15,7 @@ from clan_cli.tests.fixtures_flakes import FlakeForTest
machines.jon1 = { };
machines.jon2 = { machineClass = "nixos"; };
machines.sara = { machineClass = "darwin"; };
}"""
}""",
},
],
# Important!
@@ -27,8 +27,7 @@ from clan_cli.tests.fixtures_flakes import FlakeForTest
def test_inventory_machine_detect_class(
test_flake_with_core: FlakeForTest,
) -> None:
"""
Testing different inventory deserializations
"""Testing different inventory deserializations
Inventory should always be deserializable to a dict
"""
machine_jon1 = Machine(

View File

@@ -87,7 +87,8 @@ def get_machines_for_update(
) -> list[Machine]:
all_machines = list_machines(flake)
machines_with_tags = list_machines(
flake, ListOptions(filter=MachineFilter(tags=filter_tags))
flake,
ListOptions(filter=MachineFilter(tags=filter_tags)),
)
if filter_tags and not machines_with_tags:
@@ -101,7 +102,7 @@ def get_machines_for_update(
filter(
requires_explicit_update,
instantiate_inventory_to_machines(flake, machines_with_tags).values(),
)
),
)
# all machines that are in the clan but not included in the update list
machine_names_to_update = [m.name for m in machines_to_update]
@@ -131,7 +132,7 @@ def get_machines_for_update(
raise ClanError(msg)
machines_to_update.append(
Machine.from_inventory(name, flake, inventory_machine)
Machine.from_inventory(name, flake, inventory_machine),
)
return machines_to_update
@@ -163,7 +164,7 @@ def update_command(args: argparse.Namespace) -> None:
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.settings.secretModule",
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.deployment.requireExplicitUpdate",
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.system.clan.deployment.nixosMobileWorkaround",
]
],
)
host_key_check = args.host_key_check

View File

@@ -17,12 +17,12 @@ from clan_cli.tests.helpers import cli
"inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; };
}"""
}""",
},
["jon"], # explizit names
[], # filter tags
["jon"], # expected
)
),
],
# Important!
# tells pytest to pass these values to the fixture
@@ -55,12 +55,12 @@ def test_get_machines_for_update_single_name(
"inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; };
}"""
}""",
},
[], # explizit names
["foo"], # filter tags
["jon", "sara"], # expected
)
),
],
# Important!
# tells pytest to pass these values to the fixture
@@ -93,12 +93,12 @@ def test_get_machines_for_update_tags(
"inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; };
}"""
}""",
},
["sara"], # explizit names
["foo"], # filter tags
["sara"], # expected
)
),
],
# Important!
# tells pytest to pass these values to the fixture
@@ -131,7 +131,7 @@ def test_get_machines_for_update_tags_and_name(
"inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; };
}"""
}""",
},
[], # no explizit names
[], # no filter tags
@@ -162,7 +162,8 @@ def test_get_machines_for_update_implicit_all(
def test_update_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -19,7 +19,7 @@ def list_command(args: argparse.Namespace) -> None:
col_network = max(12, max(len(name) for name in networks))
col_priority = 8
col_module = max(
10, max(len(net.module_name.split(".")[-1]) for net in networks.values())
10, max(len(net.module_name.split(".")[-1]) for net in networks.values()),
)
col_running = 8
@@ -30,7 +30,8 @@ def list_command(args: argparse.Namespace) -> None:
# Print network entries
for network_name, network in sorted(
networks.items(), key=lambda network: -network[1].priority
networks.items(),
key=lambda network: -network[1].priority,
):
# Extract simple module name from full module path
module_name = network.module_name.split(".")[-1]
@@ -56,7 +57,7 @@ def list_command(args: argparse.Namespace) -> None:
running_status = "Error"
print(
f"{network_name:<{col_network}} {network.priority:<{col_priority}} {module_name:<{col_module}} {running_status:<{col_running}} {peers_str}"
f"{network_name:<{col_network}} {network.priority:<{col_priority}} {module_name:<{col_module}} {running_status:<{col_running}} {peers_str}",
)

View File

@@ -95,8 +95,7 @@ PROFS = ProfilerStore()
def profile(func: Callable) -> Callable:
"""
A decorator that profiles the decorated function, printing out the profiling
"""A decorator that profiles the decorated function, printing out the profiling
results with paths trimmed to three directories deep.
"""

View File

@@ -39,7 +39,8 @@ class QgaSession:
def run_nonblocking(self, cmd: list[str]) -> int:
result_pid = self.client.cmd(
"guest-exec", {"path": cmd[0], "arg": cmd[1:], "capture-output": True}
"guest-exec",
{"path": cmd[0], "arg": cmd[1:], "capture-output": True},
)
if result_pid is None:
msg = "Could not get PID from QGA"

View File

@@ -20,32 +20,23 @@ from clan_lib.errors import ClanError
class QMPError(Exception):
"""
QMP base exception
"""
"""QMP base exception"""
class QMPConnectError(QMPError):
"""
QMP connection exception
"""
"""QMP connection exception"""
class QMPCapabilitiesError(QMPError):
"""
QMP negotiate capabilities exception
"""
"""QMP negotiate capabilities exception"""
class QMPTimeoutError(QMPError):
"""
QMP timeout exception
"""
"""QMP timeout exception"""
class QEMUMonitorProtocol:
"""
Provide an API to connect to QEMU via QEMU Monitor Protocol (QMP) and then
"""Provide an API to connect to QEMU via QEMU Monitor Protocol (QMP) and then
allow to handle commands and events.
"""
@@ -58,8 +49,7 @@ class QEMUMonitorProtocol:
server: bool = False,
nickname: str | None = None,
) -> None:
"""
Create a QEMUMonitorProtocol class.
"""Create a QEMUMonitorProtocol class.
@param address: QEMU address, can be either a unix socket path (string)
or a tuple in the form ( address, port ) for a TCP
@@ -109,8 +99,7 @@ class QEMUMonitorProtocol:
return resp
def __get_events(self, wait: bool | float = False) -> None:
"""
Check for new events in the stream and cache them in __events.
"""Check for new events in the stream and cache them in __events.
@param wait (bool): block until an event is available.
@param wait (float): If wait is a float, treat it as a timeout value.
@@ -120,7 +109,6 @@ class QEMUMonitorProtocol:
@raise QMPConnectError: If wait is True but no events could be
retrieved or if some other error occurred.
"""
# Check for new events regardless and pull them into the cache:
self.__sock.setblocking(0)
try:
@@ -163,8 +151,7 @@ class QEMUMonitorProtocol:
self.close()
def connect(self, negotiate: bool = True) -> dict[str, Any] | None:
"""
Connect to the QMP Monitor and perform capabilities negotiation.
"""Connect to the QMP Monitor and perform capabilities negotiation.
@return QMP greeting dict, or None if negotiate is false
@raise OSError on socket connection errors
@@ -178,8 +165,7 @@ class QEMUMonitorProtocol:
return None
def accept(self, timeout: float | None = 15.0) -> dict[str, Any]:
"""
Await connection from QMP Monitor and perform capabilities negotiation.
"""Await connection from QMP Monitor and perform capabilities negotiation.
@param timeout: timeout in seconds (nonnegative float number, or
None). The value passed will set the behavior of the
@@ -199,8 +185,7 @@ class QEMUMonitorProtocol:
return self.__negotiate_capabilities()
def cmd_obj(self, qmp_cmd: dict[str, Any]) -> dict[str, Any] | None:
"""
Send a QMP command to the QMP Monitor.
"""Send a QMP command to the QMP Monitor.
@param qmp_cmd: QMP command to be sent as a Python dict
@return QMP response as a Python dict or None if the connection has
@@ -223,8 +208,7 @@ class QEMUMonitorProtocol:
args: dict[str, Any] | None = None,
cmd_id: dict[str, Any] | list[Any] | str | int | None = None,
) -> dict[str, Any] | None:
"""
Build a QMP command and send it to the QMP Monitor.
"""Build a QMP command and send it to the QMP Monitor.
@param name: command name (string)
@param args: command arguments (dict)
@@ -238,17 +222,14 @@ class QEMUMonitorProtocol:
return self.cmd_obj(qmp_cmd)
def command(self, cmd: str, **kwds: Any) -> Any:
"""
Build and send a QMP command to the monitor, report errors if any
"""
"""Build and send a QMP command to the monitor, report errors if any"""
ret = self.cmd(cmd, kwds)
if "error" in ret:
raise ClanError(ret["error"]["desc"])
return ret["return"]
def pull_event(self, wait: bool | float = False) -> dict[str, Any] | None:
"""
Pulls a single event.
"""Pulls a single event.
@param wait (bool): block until an event is available.
@param wait (float): If wait is a float, treat it as a timeout value.
@@ -267,8 +248,7 @@ class QEMUMonitorProtocol:
return None
def get_events(self, wait: bool | float = False) -> list[dict[str, Any]]:
"""
Get a list of available QMP events.
"""Get a list of available QMP events.
@param wait (bool): block until an event is available.
@param wait (float): If wait is a float, treat it as a timeout value.
@@ -284,23 +264,18 @@ class QEMUMonitorProtocol:
return self.__events
def clear_events(self) -> None:
"""
Clear current list of pending events.
"""
"""Clear current list of pending events."""
self.__events = []
def close(self) -> None:
"""
Close the socket and socket file.
"""
"""Close the socket and socket file."""
if self.__sock:
self.__sock.close()
if self.__sockfile:
self.__sockfile.close()
def settimeout(self, timeout: float | None) -> None:
"""
Set the socket timeout.
"""Set the socket timeout.
@param timeout (float): timeout in seconds, or None.
@note This is a wrap around socket.settimeout
@@ -308,16 +283,14 @@ class QEMUMonitorProtocol:
self.__sock.settimeout(timeout)
def get_sock_fd(self) -> int:
"""
Get the socket file descriptor.
"""Get the socket file descriptor.
@return The file descriptor number.
"""
return self.__sock.fileno()
def is_scm_available(self) -> bool:
"""
Check if the socket allows for SCM_RIGHTS.
"""Check if the socket allows for SCM_RIGHTS.
@return True if SCM_RIGHTS is available, otherwise False.
"""

View File

@@ -41,7 +41,11 @@ def users_folder(flake_dir: Path, group: str) -> Path:
class Group:
def __init__(
self, flake_dir: Path, name: str, machines: list[str], users: list[str]
self,
flake_dir: Path,
name: str,
machines: list[str],
users: list[str],
) -> None:
self.name = name
self.machines = machines
@@ -235,13 +239,18 @@ def remove_machine_command(args: argparse.Namespace) -> None:
def add_group_argument(parser: argparse.ArgumentParser) -> None:
group_action = parser.add_argument(
"group", help="the name of the secret", type=group_name_type
"group",
help="the name of the secret",
type=group_name_type,
)
add_dynamic_completer(group_action, complete_groups)
def add_secret(
flake_dir: Path, group: str, name: str, age_plugins: list[str] | None
flake_dir: Path,
group: str,
name: str,
age_plugins: list[str] | None,
) -> None:
secrets.allow_member(
secrets.groups_folder(sops_secrets_folder(flake_dir) / name),
@@ -276,7 +285,10 @@ def add_secret_command(args: argparse.Namespace) -> None:
def remove_secret(
flake_dir: Path, group: str, name: str, age_plugins: list[str]
flake_dir: Path,
group: str,
name: str,
age_plugins: list[str],
) -> None:
updated_paths = secrets.disallow_member(
secrets.groups_folder(sops_secrets_folder(flake_dir) / name),
@@ -313,22 +325,28 @@ def register_groups_parser(parser: argparse.ArgumentParser) -> None:
# Add user
add_machine_parser = subparser.add_parser(
"add-machine", help="add a machine to group"
"add-machine",
help="add a machine to group",
)
add_group_argument(add_machine_parser)
add_machine_action = add_machine_parser.add_argument(
"machine", help="the name of the machines to add", type=machine_name_type
"machine",
help="the name of the machines to add",
type=machine_name_type,
)
add_dynamic_completer(add_machine_action, complete_machines)
add_machine_parser.set_defaults(func=add_machine_command)
# Remove machine
remove_machine_parser = subparser.add_parser(
"remove-machine", help="remove a machine from group"
"remove-machine",
help="remove a machine from group",
)
add_group_argument(remove_machine_parser)
remove_machine_action = remove_machine_parser.add_argument(
"machine", help="the name of the machines to remove", type=machine_name_type
"machine",
help="the name of the machines to remove",
type=machine_name_type,
)
add_dynamic_completer(remove_machine_action, complete_machines)
remove_machine_parser.set_defaults(func=remove_machine_command)
@@ -337,40 +355,51 @@ def register_groups_parser(parser: argparse.ArgumentParser) -> None:
add_user_parser = subparser.add_parser("add-user", help="add a user to group")
add_group_argument(add_user_parser)
add_user_action = add_user_parser.add_argument(
"user", help="the name of the user to add", type=user_name_type
"user",
help="the name of the user to add",
type=user_name_type,
)
add_dynamic_completer(add_user_action, complete_users)
add_user_parser.set_defaults(func=add_user_command)
# Remove user
remove_user_parser = subparser.add_parser(
"remove-user", help="remove a user from a group"
"remove-user",
help="remove a user from a group",
)
add_group_argument(remove_user_parser)
remove_user_action = remove_user_parser.add_argument(
"user", help="the name of the user to remove", type=user_name_type
"user",
help="the name of the user to remove",
type=user_name_type,
)
add_dynamic_completer(remove_user_action, complete_users)
remove_user_parser.set_defaults(func=remove_user_command)
# Add secret
add_secret_parser = subparser.add_parser(
"add-secret", help="allow a groups to access a secret"
"add-secret",
help="allow a groups to access a secret",
)
add_group_argument(add_secret_parser)
add_secret_action = add_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type
"secret",
help="the name of the secret",
type=secret_name_type,
)
add_dynamic_completer(add_secret_action, complete_secrets)
add_secret_parser.set_defaults(func=add_secret_command)
# Remove secret
remove_secret_parser = subparser.add_parser(
"remove-secret", help="remove a group's access to a secret"
"remove-secret",
help="remove a group's access to a secret",
)
add_group_argument(remove_secret_parser)
remove_secret_action = remove_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type
"secret",
help="the name of the secret",
type=secret_name_type,
)
add_dynamic_completer(remove_secret_action, complete_secrets)
remove_secret_parser.set_defaults(func=remove_secret_command)

View File

@@ -19,8 +19,7 @@ log = logging.getLogger(__name__)
def generate_key() -> sops.SopsKey:
"""
Generate a new age key and return it as a SopsKey.
"""Generate a new age key and return it as a SopsKey.
This function does not check if the key already exists.
It will generate a new key every time it is called.
@@ -28,14 +27,16 @@ def generate_key() -> sops.SopsKey:
Use 'check_key_exists' to check if a key already exists.
Before calling this function if you dont want to generate a new key.
"""
path = default_admin_private_key_path()
_, pub_key = generate_private_key(out_file=path)
log.info(
f"Generated age private key at '{path}' for your user.\nPlease back it up on a secure location or you will lose access to your secrets."
f"Generated age private key at '{path}' for your user.\nPlease back it up on a secure location or you will lose access to your secrets.",
)
return sops.SopsKey(
pub_key, username="", key_type=sops.KeyType.AGE, source=str(path)
pub_key,
username="",
key_type=sops.KeyType.AGE,
source=str(path),
)
@@ -49,7 +50,8 @@ def generate_command(args: argparse.Namespace) -> None:
key_type = key.key_type.name.lower()
print(f"{key.key_type.name} key {key.pubkey} is already set", file=sys.stderr)
print(
f"Add your {key_type} public key to the repository with:", file=sys.stderr
f"Add your {key_type} public key to the repository with:",
file=sys.stderr,
)
print(
f"clan secrets users add <username> --{key_type}-key {key.pubkey}",

View File

@@ -59,16 +59,12 @@ def get_machine_pubkey(flake_dir: Path, name: str) -> str:
def has_machine(flake_dir: Path, name: str) -> bool:
"""
Checks if a machine exists in the sops machines folder
"""
"""Checks if a machine exists in the sops machines folder"""
return (sops_machines_folder(flake_dir) / name / "key.json").exists()
def list_sops_machines(flake_dir: Path) -> list[str]:
"""
Lists all machines in the sops machines folder
"""
"""Lists all machines in the sops machines folder"""
path = sops_machines_folder(flake_dir)
def validate(name: str) -> bool:
@@ -97,7 +93,10 @@ def add_secret(
def remove_secret(
flake_dir: Path, machine: str, secret: str, age_plugins: list[str] | None
flake_dir: Path,
machine: str,
secret: str,
age_plugins: list[str] | None,
) -> None:
updated_paths = secrets.disallow_member(
secrets.machines_folder(sops_secrets_folder(flake_dir) / secret),
@@ -174,7 +173,9 @@ def register_machines_parser(parser: argparse.ArgumentParser) -> None:
default=False,
)
add_machine_action = add_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type
"machine",
help="the name of the machine",
type=machine_name_type,
)
add_dynamic_completer(add_machine_action, complete_machines)
add_parser.add_argument(
@@ -187,7 +188,9 @@ def register_machines_parser(parser: argparse.ArgumentParser) -> None:
# Parser
get_parser = subparser.add_parser("get", help="get a machine public key")
get_machine_parser = get_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type
"machine",
help="the name of the machine",
type=machine_name_type,
)
add_dynamic_completer(get_machine_parser, complete_machines)
get_parser.set_defaults(func=get_command)
@@ -195,35 +198,47 @@ def register_machines_parser(parser: argparse.ArgumentParser) -> None:
# Parser
remove_parser = subparser.add_parser("remove", help="remove a machine")
remove_machine_parser = remove_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type
"machine",
help="the name of the machine",
type=machine_name_type,
)
add_dynamic_completer(remove_machine_parser, complete_machines)
remove_parser.set_defaults(func=remove_command)
# Parser
add_secret_parser = subparser.add_parser(
"add-secret", help="allow a machine to access a secret"
"add-secret",
help="allow a machine to access a secret",
)
machine_add_secret_parser = add_secret_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type
"machine",
help="the name of the machine",
type=machine_name_type,
)
add_dynamic_completer(machine_add_secret_parser, complete_machines)
add_secret_action = add_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type
"secret",
help="the name of the secret",
type=secret_name_type,
)
add_dynamic_completer(add_secret_action, complete_secrets)
add_secret_parser.set_defaults(func=add_secret_command)
# Parser
remove_secret_parser = subparser.add_parser(
"remove-secret", help="remove a group's access to a secret"
"remove-secret",
help="remove a group's access to a secret",
)
machine_remove_parser = remove_secret_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type
"machine",
help="the name of the machine",
type=machine_name_type,
)
add_dynamic_completer(machine_remove_parser, complete_machines)
remove_secret_action = remove_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type
"secret",
help="the name of the secret",
type=secret_name_type,
)
add_dynamic_completer(remove_secret_action, complete_secrets)
remove_secret_parser.set_defaults(func=remove_secret_command)

View File

@@ -50,7 +50,8 @@ def list_generators_secrets(generators_path: Path) -> list[Path]:
return has_secret(generator_path / name)
for obj in list_objects(
generator_path, functools.partial(validate, generator_path)
generator_path,
functools.partial(validate, generator_path),
):
paths.append(generator_path / obj)
return paths
@@ -89,7 +90,7 @@ def update_secrets(
changed_files.extend(cleanup_dangling_symlinks(path / "groups"))
changed_files.extend(cleanup_dangling_symlinks(path / "machines"))
changed_files.extend(
update_keys(path, collect_keys_for_path(path), age_plugins=age_plugins)
update_keys(path, collect_keys_for_path(path), age_plugins=age_plugins),
)
return changed_files
@@ -120,7 +121,7 @@ def collect_keys_for_type(folder: Path) -> set[sops.SopsKey]:
kind = target.parent.name
if folder.name != kind:
log.warning(
f"Expected {p} to point to {folder} but points to {target.parent}"
f"Expected {p} to point to {folder} but points to {target.parent}",
)
continue
keys.update(read_keys(target))
@@ -160,7 +161,7 @@ def encrypt_secret(
admin_keys = sops.ensure_admin_public_keys(flake_dir)
if not admin_keys:
# todo double check the correct command to run
# TODO double check the correct command to run
msg = "No keys found. Please run 'clan secrets add-key' to add a key."
raise ClanError(msg)
@@ -179,7 +180,7 @@ def encrypt_secret(
user,
do_update_keys,
age_plugins=age_plugins,
)
),
)
for machine in add_machines:
@@ -190,7 +191,7 @@ def encrypt_secret(
machine,
do_update_keys,
age_plugins=age_plugins,
)
),
)
for group in add_groups:
@@ -201,7 +202,7 @@ def encrypt_secret(
group,
do_update_keys,
age_plugins=age_plugins,
)
),
)
recipient_keys = collect_keys_for_path(secret_path)
@@ -216,7 +217,7 @@ def encrypt_secret(
username,
do_update_keys,
age_plugins=age_plugins,
)
),
)
secret_path = secret_path / "secret"
@@ -310,13 +311,15 @@ def allow_member(
group_folder.parent,
collect_keys_for_path(group_folder.parent),
age_plugins=age_plugins,
)
),
)
return changed
def disallow_member(
group_folder: Path, name: str, age_plugins: list[str] | None
group_folder: Path,
name: str,
age_plugins: list[str] | None,
) -> list[Path]:
target = group_folder / name
if not target.exists():
@@ -349,7 +352,8 @@ def has_secret(secret_path: Path) -> bool:
def list_secrets(
flake_dir: Path, filter_fn: Callable[[str], bool] | None = None
flake_dir: Path,
filter_fn: Callable[[str], bool] | None = None,
) -> list[str]:
path = sops_secrets_folder(flake_dir)

View File

@@ -66,7 +66,7 @@ class KeyType(enum.Enum):
for public_key in get_public_age_keys(content):
log.debug(
f"Found age public key from a private key "
f"in {key_path}: {public_key}"
f"in {key_path}: {public_key}",
)
keyring.append(
@@ -75,7 +75,7 @@ class KeyType(enum.Enum):
username="",
key_type=self,
source=str(key_path),
)
),
)
except ClanError as e:
error_msg = f"Failed to read age keys from {key_path}"
@@ -96,7 +96,7 @@ class KeyType(enum.Enum):
for public_key in get_public_age_keys(content):
log.debug(
f"Found age public key from a private key "
f"in the environment (SOPS_AGE_KEY): {public_key}"
f"in the environment (SOPS_AGE_KEY): {public_key}",
)
keyring.append(
@@ -105,7 +105,7 @@ class KeyType(enum.Enum):
username="",
key_type=self,
source="SOPS_AGE_KEY",
)
),
)
except ClanError as e:
error_msg = "Failed to read age keys from SOPS_AGE_KEY"
@@ -126,8 +126,11 @@ class KeyType(enum.Enum):
log.debug(msg)
keyring.append(
SopsKey(
pubkey=fp, username="", key_type=self, source="SOPS_PGP_FP"
)
pubkey=fp,
username="",
key_type=self,
source="SOPS_PGP_FP",
),
)
return keyring
@@ -389,7 +392,7 @@ def get_user_name(flake_dir: Path, user: str) -> str:
"""Ask the user for their name until a unique one is provided."""
while True:
name = input(
f"Your key is not yet added to the repository. Enter your user name for which your sops key will be stored in the repository [default: {user}]: "
f"Your key is not yet added to the repository. Enter your user name for which your sops key will be stored in the repository [default: {user}]: ",
)
if name:
user = name
@@ -455,7 +458,9 @@ def ensure_admin_public_keys(flake_dir: Path) -> set[SopsKey]:
def update_keys(
secret_path: Path, keys: Iterable[SopsKey], age_plugins: list[str] | None = None
secret_path: Path,
keys: Iterable[SopsKey],
age_plugins: list[str] | None = None,
) -> list[Path]:
secret_path = secret_path / "secret"
error_msg = f"Could not update keys for {secret_path}"
@@ -565,7 +570,7 @@ def get_recipients(secret_path: Path) -> set[SopsKey]:
username="",
key_type=key_type,
source="sops_file",
)
),
)
return keys

View File

@@ -66,7 +66,7 @@ def remove_user(flake_dir: Path, name: str) -> None:
continue
log.info(f"Removing user {name} from group {group}")
updated_paths.extend(
groups.remove_member(flake_dir, group.name, groups.users_folder, name)
groups.remove_member(flake_dir, group.name, groups.users_folder, name),
)
# Remove the user's key:
updated_paths.extend(remove_object(sops_users_folder(flake_dir), name))
@@ -96,7 +96,10 @@ def list_users(flake_dir: Path) -> list[str]:
def add_secret(
flake_dir: Path, user: str, secret: str, age_plugins: list[str] | None
flake_dir: Path,
user: str,
secret: str,
age_plugins: list[str] | None,
) -> None:
updated_paths = secrets.allow_member(
secrets.users_folder(sops_secrets_folder(flake_dir) / secret),
@@ -112,10 +115,15 @@ def add_secret(
def remove_secret(
flake_dir: Path, user: str, secret: str, age_plugins: list[str] | None
flake_dir: Path,
user: str,
secret: str,
age_plugins: list[str] | None,
) -> None:
updated_paths = secrets.disallow_member(
secrets.users_folder(sops_secrets_folder(flake_dir) / secret), user, age_plugins
secrets.users_folder(sops_secrets_folder(flake_dir) / secret),
user,
age_plugins,
)
commit_files(
updated_paths,
@@ -189,7 +197,7 @@ def _key_args(args: argparse.Namespace) -> Iterable[sops.SopsKey]:
]
if args.agekey:
age_keys.append(
sops.SopsKey(args.agekey, "", sops.KeyType.AGE, source="cmdline")
sops.SopsKey(args.agekey, "", sops.KeyType.AGE, source="cmdline"),
)
pgp_keys = [
@@ -260,7 +268,10 @@ def register_users_parser(parser: argparse.ArgumentParser) -> None:
add_parser = subparser.add_parser("add", help="add a user")
add_parser.add_argument(
"-f", "--force", help="overwrite existing user", action="store_true"
"-f",
"--force",
help="overwrite existing user",
action="store_true",
)
add_parser.add_argument("user", help="the name of the user", type=user_name_type)
_add_key_flags(add_parser)
@@ -268,59 +279,79 @@ def register_users_parser(parser: argparse.ArgumentParser) -> None:
get_parser = subparser.add_parser("get", help="get a user public key")
get_user_action = get_parser.add_argument(
"user", help="the name of the user", type=user_name_type
"user",
help="the name of the user",
type=user_name_type,
)
add_dynamic_completer(get_user_action, complete_users)
get_parser.set_defaults(func=get_command)
remove_parser = subparser.add_parser("remove", help="remove a user")
remove_user_action = remove_parser.add_argument(
"user", help="the name of the user", type=user_name_type
"user",
help="the name of the user",
type=user_name_type,
)
add_dynamic_completer(remove_user_action, complete_users)
remove_parser.set_defaults(func=remove_command)
add_secret_parser = subparser.add_parser(
"add-secret", help="allow a user to access a secret"
"add-secret",
help="allow a user to access a secret",
)
add_secret_user_action = add_secret_parser.add_argument(
"user", help="the name of the user", type=user_name_type
"user",
help="the name of the user",
type=user_name_type,
)
add_dynamic_completer(add_secret_user_action, complete_users)
add_secrets_action = add_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type
"secret",
help="the name of the secret",
type=secret_name_type,
)
add_dynamic_completer(add_secrets_action, complete_secrets)
add_secret_parser.set_defaults(func=add_secret_command)
remove_secret_parser = subparser.add_parser(
"remove-secret", help="remove a user's access to a secret"
"remove-secret",
help="remove a user's access to a secret",
)
remove_secret_user_action = remove_secret_parser.add_argument(
"user", help="the name of the group", type=user_name_type
"user",
help="the name of the group",
type=user_name_type,
)
add_dynamic_completer(remove_secret_user_action, complete_users)
remove_secrets_action = remove_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type
"secret",
help="the name of the secret",
type=secret_name_type,
)
add_dynamic_completer(remove_secrets_action, complete_secrets)
remove_secret_parser.set_defaults(func=remove_secret_command)
add_key_parser = subparser.add_parser(
"add-key", help="add one or more keys for a user"
"add-key",
help="add one or more keys for a user",
)
add_key_user_action = add_key_parser.add_argument(
"user", help="the name of the user", type=user_name_type
"user",
help="the name of the user",
type=user_name_type,
)
add_dynamic_completer(add_key_user_action, complete_users)
_add_key_flags(add_key_parser)
add_key_parser.set_defaults(func=add_key_command)
remove_key_parser = subparser.add_parser(
"remove-key", help="remove one or more keys for a user"
"remove-key",
help="remove one or more keys for a user",
)
remove_key_user_action = remove_key_parser.add_argument(
"user", help="the name of the user", type=user_name_type
"user",
help="the name of the user",
type=user_name_type,
)
add_dynamic_completer(remove_key_user_action, complete_users)
_add_key_flags(remove_key_parser)

View File

@@ -64,7 +64,8 @@ def ssh_command(args: argparse.Namespace) -> None:
ssh_options[name] = value
remote = remote.override(
host_key_check=args.host_key_check, ssh_options=ssh_options
host_key_check=args.host_key_check,
ssh_options=ssh_options,
)
if args.remote_command:
remote.interactive_ssh(args.remote_command)

View File

@@ -147,7 +147,7 @@ def test_ssh_shell_from_deploy(
str(success_txt),
"&&",
"exit 0",
]
],
)
assert success_txt.exists()

View File

@@ -25,7 +25,7 @@ def list_state_folders(machine: Machine, service: None | str = None) -> None:
[
f"{flake}#nixosConfigurations.{machine.name}.config.clan.core.state",
"--json",
]
],
)
res = "{}"
@@ -80,7 +80,7 @@ def list_state_folders(machine: Machine, service: None | str = None) -> None:
if post_restore:
print(f" postRestoreCommand: {post_restore}")
print("")
print()
def list_command(args: argparse.Namespace) -> None:

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core
def test_state_list_vm1(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput
test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(["state", "list", "vm1", "--flake", str(test_flake_with_core.path)])
@@ -19,7 +20,8 @@ def test_state_list_vm1(
@pytest.mark.with_core
def test_state_list_vm2(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput
test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(["state", "list", "vm2", "--flake", str(test_flake_with_core.path)])

View File

@@ -15,7 +15,8 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
)
list_parser = subparser.add_parser("list", help="List available templates")
apply_parser = subparser.add_parser(
"apply", help="Apply a template of the specified type"
"apply",
help="Apply a template of the specified type",
)
register_list_parser(list_parser)
register_apply_parser(apply_parser)

View File

@@ -12,10 +12,11 @@ def list_command(args: argparse.Namespace) -> None:
# Display all templates
for i, (template_type, _builtin_template_set) in enumerate(
templates.builtins.items()
templates.builtins.items(),
):
builtin_template_set: TemplateClanType | None = templates.builtins.get(
template_type, None
template_type,
None,
) # type: ignore
if not builtin_template_set:
continue
@@ -32,7 +33,8 @@ def list_command(args: argparse.Namespace) -> None:
for i, (input_name, input_templates) in enumerate(templates.custom.items()):
custom_templates: TemplateClanType | None = input_templates.get(
template_type, None
template_type,
None,
) # type: ignore
if not custom_templates:
continue
@@ -48,11 +50,11 @@ def list_command(args: argparse.Namespace) -> None:
is_last_template = i == len(custom_templates.items()) - 1
if not is_last_template:
print(
f"{prefix} ├── {name}: {template.get('description', 'no description')}"
f"{prefix} ├── {name}: {template.get('description', 'no description')}",
)
else:
print(
f"{prefix} └── {name}: {template.get('description', 'no description')}"
f"{prefix} └── {name}: {template.get('description', 'no description')}",
)

View File

@@ -9,7 +9,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core
def test_templates_list(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput
test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(["templates", "list", "--flake", str(test_flake_with_core.path)])
@@ -26,7 +27,8 @@ def test_templates_list(
@pytest.mark.with_core
def test_templates_list_outside_clan(
capture_output: CaptureOutput, temp_dir: Path
capture_output: CaptureOutput,
temp_dir: Path,
) -> None:
"""Test templates list command when run outside a clan directory."""
with capture_output as output:

View File

@@ -37,7 +37,7 @@ class SopsSetup:
"--user",
self.user,
"--no-interactive",
]
],
)

View File

@@ -54,8 +54,7 @@ class Command:
@pytest.fixture
def command() -> Iterator[Command]:
"""
Starts a background command. The process is automatically terminated in the end.
"""Starts a background command. The process is automatically terminated in the end.
>>> p = command.run(["some", "daemon"])
>>> print(p.pid)
"""

View File

@@ -39,8 +39,7 @@ def def_value() -> defaultdict:
def nested_dict() -> defaultdict:
"""
Creates a defaultdict that allows for arbitrary levels of nesting.
"""Creates a defaultdict that allows for arbitrary levels of nesting.
For example: d['a']['b']['c'] = value
"""
return defaultdict(def_value)
@@ -75,7 +74,8 @@ def substitute(
if clan_core_replacement:
line = line.replace("__CLAN_CORE__", clan_core_replacement)
line = line.replace(
"git+https://git.clan.lol/clan/clan-core", clan_core_replacement
"git+https://git.clan.lol/clan/clan-core",
clan_core_replacement,
)
line = line.replace(
"https://git.clan.lol/clan/clan-core/archive/main.tar.gz",
@@ -133,8 +133,7 @@ def init_git(monkeypatch: pytest.MonkeyPatch, flake: Path) -> None:
class ClanFlake:
"""
This class holds all attributes for generating a clan flake.
"""This class holds all attributes for generating a clan flake.
For example, inventory and machine configs can be set via self.inventory and self.machines["my_machine"] = {...}.
Whenever a flake's configuration is changed, it needs to be re-generated by calling refresh().
@@ -179,7 +178,7 @@ class ClanFlake:
if not suppress_tmp_home_warning:
if "/tmp" not in str(os.environ.get("HOME")):
log.warning(
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}"
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}",
)
def copy(
@@ -236,7 +235,7 @@ class ClanFlake:
inventory_path = self.path / "inventory.json"
inventory_path.write_text(json.dumps(self.inventory, indent=2))
imports = "\n".join(
[f"clan-core.clanModules.{module}" for module in self.clan_modules]
[f"clan-core.clanModules.{module}" for module in self.clan_modules],
)
for machine_name, machine_config in self.machines.items():
configuration_nix = (
@@ -252,7 +251,7 @@ class ClanFlake:
{imports}
];
}}
"""
""",
)
machine = Machine(name=machine_name, flake=Flake(str(self.path)))
set_machine_settings(machine, machine_config)
@@ -309,8 +308,7 @@ def create_flake(
machine_configs: dict[str, dict] | None = None,
inventory_expr: str = r"{}",
) -> Iterator[FlakeForTest]:
"""
Creates a flake with the given name and machines.
"""Creates a flake with the given name and machines.
The machine names map to the machines in ./test_machines
"""
if machine_configs is None:
@@ -372,7 +370,7 @@ def create_flake(
if "/tmp" not in str(os.environ.get("HOME")):
log.warning(
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}"
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}",
)
init_git(monkeypatch, flake)
@@ -382,7 +380,8 @@ def create_flake(
@pytest.fixture
def test_flake(
monkeypatch: pytest.MonkeyPatch, temporary_home: Path
monkeypatch: pytest.MonkeyPatch,
temporary_home: Path,
) -> Iterator[FlakeForTest]:
yield from create_flake(
temporary_home=temporary_home,
@@ -429,8 +428,7 @@ def writable_clan_core(
clan_core: Path,
tmp_path: Path,
) -> Path:
"""
Creates a writable copy of clan_core in a temporary directory.
"""Creates a writable copy of clan_core in a temporary directory.
If clan_core is a git repo, copies tracked files and uncommitted changes.
Removes vars/ and sops/ directories if they exist.
"""
@@ -454,7 +452,9 @@ def writable_clan_core(
# Copy .git directory to maintain git functionality
if (clan_core / ".git").is_dir():
shutil.copytree(
clan_core / ".git", temp_flake / ".git", ignore_dangling_symlinks=True
clan_core / ".git",
temp_flake / ".git",
ignore_dangling_symlinks=True,
)
else:
# It's a git file (for submodules/worktrees)
@@ -478,9 +478,7 @@ def vm_test_flake(
clan_core: Path,
tmp_path: Path,
) -> Path:
"""
Creates a test flake that imports the VM test nixOS modules from clan-core.
"""
"""Creates a test flake that imports the VM test nixOS modules from clan-core."""
test_flake_dir = tmp_path / "test-flake"
test_flake_dir.mkdir(parents=True)

View File

@@ -18,7 +18,7 @@ def hosts(sshd: Sshd) -> list[Remote]:
private_key=Path(sshd.key),
host_key_check="none",
command_prefix="local_test",
)
),
]
return group

View File

@@ -13,31 +13,23 @@ else:
@pytest.fixture(scope="session")
def project_root() -> Path:
"""
Root directory the clan-cli
"""
"""Root directory the clan-cli"""
return PROJECT_ROOT
@pytest.fixture(scope="session")
def test_root() -> Path:
"""
Root directory of the tests
"""
"""Root directory of the tests"""
return TEST_ROOT
@pytest.fixture(scope="session")
def test_lib_root() -> Path:
"""
Root directory of the clan-lib tests
"""
"""Root directory of the clan-lib tests"""
return PROJECT_ROOT.parent / "clan_lib" / "tests"
@pytest.fixture(scope="session")
def clan_core() -> Path:
"""
Directory of the clan-core flake
"""
"""Directory of the clan-core flake"""
return CLAN_CORE

View File

@@ -29,7 +29,12 @@ class Sshd:
class SshdConfig:
def __init__(
self, path: Path, login_shell: Path, key: str, preload_lib: Path, log_file: Path
self,
path: Path,
login_shell: Path,
key: str,
preload_lib: Path,
log_file: Path,
) -> None:
self.path = path
self.login_shell = login_shell
@@ -53,7 +58,7 @@ def sshd_config(test_root: Path) -> Iterator[SshdConfig]:
sftp_server = sshdp.parent.parent / "libexec" / "sftp-server"
assert sftp_server is not None
content = string.Template(template).substitute(
{"host_key": host_key, "sftp_server": sftp_server}
{"host_key": host_key, "sftp_server": sftp_server},
)
config = tmpdir / "sshd_config"
config.write_text(content)
@@ -74,7 +79,7 @@ if [[ -f /etc/profile ]]; then
fi
export PATH="{bin_path}:{path}"
exec {bash} -l "${{@}}"
"""
""",
)
login_shell.chmod(0o755)
@@ -82,7 +87,7 @@ exec {bash} -l "${{@}}"
f"""#!{bash}
shift
exec "${{@}}"
"""
""",
)
fake_sudo.chmod(0o755)

View File

@@ -21,16 +21,17 @@ def should_skip(file_path: Path, excludes: list[Path]) -> bool:
def find_dataclasses_in_directory(
directory: Path, exclude_paths: list[str] | None = None
directory: Path,
exclude_paths: list[str] | None = None,
) -> list[tuple[Path, str]]:
"""
Find all dataclass classes in all Python files within a nested directory.
"""Find all dataclass classes in all Python files within a nested directory.
Args:
directory (str): The root directory to start searching from.
Returns:
List[Tuple[str, str]]: A list of tuples containing the file path and the dataclass name.
"""
if exclude_paths is None:
exclude_paths = []
@@ -69,10 +70,11 @@ def find_dataclasses_in_directory(
def load_dataclass_from_file(
file_path: Path, class_name: str, root_dir: str
file_path: Path,
class_name: str,
root_dir: str,
) -> type | None:
"""
Load a dataclass from a given file path.
"""Load a dataclass from a given file path.
Args:
file_path (str): Path to the file.
@@ -80,6 +82,7 @@ def load_dataclass_from_file(
Returns:
List[Type]: The dataclass type if found, else an empty list.
"""
module_name = (
os.path.relpath(file_path, root_dir).replace(os.path.sep, ".").rstrip(".py")
@@ -109,15 +112,14 @@ def load_dataclass_from_file(
dataclass_type = getattr(module, class_name, None)
if dataclass_type and is_dataclass(dataclass_type):
return cast(type, dataclass_type)
return cast("type", dataclass_type)
msg = f"Could not load dataclass {class_name} from file: {file_path}"
raise ClanError(msg)
def test_all_dataclasses() -> None:
"""
This Test ensures that all dataclasses are compatible with the API.
"""This Test ensures that all dataclasses are compatible with the API.
It will load all dataclasses from the clan_cli directory and
generate a JSON schema for each of them.
@@ -125,7 +127,6 @@ def test_all_dataclasses() -> None:
It will fail if any dataclass cannot be converted to JSON schema.
This means the dataclass in its current form is not compatible with the API.
"""
# Excludes:
# - API includes Type Generic wrappers, that are not known in the init file.
excludes = [

View File

@@ -14,5 +14,5 @@ def test_backups(
"--flake",
str(test_flake_with_core.path),
"vm1",
]
],
)

View File

@@ -139,7 +139,7 @@ def test_create_flake_fallback_from_non_clan_directory(
monkeypatch.setenv("LOGNAME", "testuser")
cli.run(
["flakes", "create", str(new_clan_dir), "--template=default", "--no-update"]
["flakes", "create", str(new_clan_dir), "--template=default", "--no-update"],
)
assert (new_clan_dir / "flake.nix").exists()
@@ -157,7 +157,7 @@ def test_create_flake_with_local_template_reference(
# TODO: should error with: localFlake does not export myLocalTemplate clan template
cli.run(
["flakes", "create", str(new_clan_dir), "--template=.#default", "--no-update"]
["flakes", "create", str(new_clan_dir), "--template=.#default", "--no-update"],
)
assert (new_clan_dir / "flake.nix").exists()

View File

@@ -1,17 +1,13 @@
from typing import TYPE_CHECKING
import pytest
from clan_cli.tests.fixtures_flakes import FlakeForTest
from clan_cli.tests.helpers import cli
from clan_cli.tests.stdout import CaptureOutput
if TYPE_CHECKING:
pass
@pytest.mark.with_core
def test_flakes_inspect(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput
test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(
@@ -22,6 +18,6 @@ def test_flakes_inspect(
str(test_flake_with_core.path),
"--machine",
"vm1",
]
],
)
assert "Icon" in output.out

View File

@@ -19,7 +19,8 @@ def test_commit_file(git_repo: Path) -> None:
# check that the latest commit message is correct
assert (
subprocess.check_output(
["git", "log", "-1", "--pretty=%B"], cwd=git_repo
["git", "log", "-1", "--pretty=%B"],
cwd=git_repo,
).decode("utf-8")
== "test commit\n\n"
)
@@ -59,7 +60,8 @@ def test_clan_flake_in_subdir(git_repo: Path, monkeypatch: pytest.MonkeyPatch) -
# check that the latest commit message is correct
assert (
subprocess.check_output(
["git", "log", "-1", "--pretty=%B"], cwd=git_repo
["git", "log", "-1", "--pretty=%B"],
cwd=git_repo,
).decode("utf-8")
== "test commit\n\n"
)

View File

@@ -28,7 +28,7 @@ def test_import_sops(
str(test_flake_with_core.path),
"machine1",
age_keys[0].pubkey,
]
],
)
cli.run(
[
@@ -39,7 +39,7 @@ def test_import_sops(
str(test_flake_with_core.path),
"user1",
age_keys[1].pubkey,
]
],
)
cli.run(
[
@@ -50,7 +50,7 @@ def test_import_sops(
str(test_flake_with_core.path),
"user2",
age_keys[2].pubkey,
]
],
)
cli.run(
[
@@ -61,7 +61,7 @@ def test_import_sops(
str(test_flake_with_core.path),
"group1",
"user1",
]
],
)
cli.run(
[
@@ -72,7 +72,7 @@ def test_import_sops(
str(test_flake_with_core.path),
"group1",
"user2",
]
],
)
# To edit:
@@ -98,6 +98,6 @@ def test_import_sops(
with capture_output as output:
cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "secret-key"]
["secrets", "get", "--flake", str(test_flake_with_core.path), "secret-key"],
)
assert output.out == "secret-value"

View File

@@ -16,7 +16,7 @@ from clan_lib.persist.inventory_store import InventoryStore
"inventory_expr": r"""{
machines.jon = {};
machines.sara = {};
}"""
}""",
},
# TODO: Test
# - Function modules
@@ -38,14 +38,13 @@ from clan_lib.persist.inventory_store import InventoryStore
def test_inventory_deserialize_variants(
test_flake_with_core: FlakeForTest,
) -> None:
"""
Testing different inventory deserializations
"""Testing different inventory deserializations
Inventory should always be deserializable to a dict
"""
inventory_store = InventoryStore(Flake(str(test_flake_with_core.path)))
# Cast the inventory to a dict for the following assertions
inventory = cast(dict[str, Any], inventory_store.read())
inventory = cast("dict[str, Any]", inventory_store.read())
# Check that the inventory is a dict
assert isinstance(inventory, dict)

View File

@@ -27,7 +27,7 @@ def test_machine_subcommands(
"machine1",
"--tags",
"vm",
]
],
)
# Usually this is done by `inventory.write` but we created a separate flake object in the test that now holds stale data
inventory_store._flake.invalidate_cache()
@@ -47,7 +47,7 @@ def test_machine_subcommands(
assert "vm2" in output.out
cli.run(
["machines", "delete", "--flake", str(test_flake_with_core.path), "machine1"]
["machines", "delete", "--flake", str(test_flake_with_core.path), "machine1"],
)
# See comment above
inventory_store._flake.invalidate_cache()
@@ -105,7 +105,7 @@ def test_machines_update_nonexistent_machine(
"--flake",
str(test_flake_with_core.path),
"nonexistent-machine",
]
],
)
error_message = str(exc_info.value)
@@ -130,7 +130,7 @@ def test_machines_update_typo_in_machine_name(
"--flake",
str(test_flake_with_core.path),
"v1", # typo of "vm1"
]
],
)
error_message = str(exc_info.value)

View File

@@ -51,7 +51,7 @@ def _test_identities(
str(test_flake_with_core.path),
"foo",
age_keys[0].pubkey,
]
],
)
assert (sops_folder / what / "foo" / "key.json").exists()
@@ -64,7 +64,7 @@ def _test_identities(
str(test_flake_with_core.path),
"admin",
admin_age_key.pubkey,
]
],
)
with pytest.raises(ClanError): # raises "foo already exists"
@@ -77,7 +77,7 @@ def _test_identities(
str(test_flake_with_core.path),
"foo",
age_keys[0].pubkey,
]
],
)
with monkeypatch.context() as m:
@@ -93,7 +93,7 @@ def _test_identities(
f"--{what_singular}",
"foo",
test_secret_name,
]
],
)
assert_secrets_file_recipients(
@@ -114,7 +114,7 @@ def _test_identities(
"-f",
"foo",
age_keys[1].privkey,
]
],
)
assert_secrets_file_recipients(
test_flake_with_core.path,
@@ -131,7 +131,7 @@ def _test_identities(
"--flake",
str(test_flake_with_core.path),
"foo",
]
],
)
assert age_keys[1].pubkey in output.out
@@ -140,7 +140,7 @@ def _test_identities(
assert "foo" in output.out
cli.run(
["secrets", what, "remove", "--flake", str(test_flake_with_core.path), "foo"]
["secrets", what, "remove", "--flake", str(test_flake_with_core.path), "foo"],
)
assert not (sops_folder / what / "foo" / "key.json").exists()
@@ -153,7 +153,7 @@ def _test_identities(
"--flake",
str(test_flake_with_core.path),
"foo",
]
],
)
with capture_output as output:
@@ -178,7 +178,11 @@ def test_users(
) -> None:
with monkeypatch.context():
_test_identities(
"users", test_flake_with_core, capture_output, age_keys, monkeypatch
"users",
test_flake_with_core,
capture_output,
age_keys,
monkeypatch,
)
@@ -208,7 +212,7 @@ def test_multiple_user_keys(
str(test_flake_with_core.path),
user,
*[f"--age-key={key.pubkey}" for key in user_keys],
]
],
)
assert (sops_folder / "users" / user / "key.json").exists()
@@ -222,7 +226,7 @@ def test_multiple_user_keys(
"--flake",
str(test_flake_with_core.path),
user,
]
],
)
for user_key in user_keys:
@@ -249,7 +253,7 @@ def test_multiple_user_keys(
"--flake",
str(test_flake_with_core.path),
secret_name,
]
],
)
# check the secret has each of our user's keys as a recipient
@@ -268,7 +272,7 @@ def test_multiple_user_keys(
"--flake",
str(test_flake_with_core.path),
secret_name,
]
],
)
assert secret_value in output.out
@@ -295,7 +299,7 @@ def test_multiple_user_keys(
str(test_flake_with_core.path),
user,
key_to_remove.pubkey,
]
],
)
# check the secret has been updated
@@ -315,7 +319,7 @@ def test_multiple_user_keys(
str(test_flake_with_core.path),
user,
key_to_remove.pubkey,
]
],
)
# check the secret has been updated
@@ -334,7 +338,11 @@ def test_machines(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_test_identities(
"machines", test_flake_with_core, capture_output, age_keys, monkeypatch
"machines",
test_flake_with_core,
capture_output,
age_keys,
monkeypatch,
)
@@ -347,7 +355,7 @@ def test_groups(
) -> None:
with capture_output as output:
cli.run(
["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)]
["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)],
)
assert output.out == ""
@@ -365,7 +373,7 @@ def test_groups(
str(test_flake_with_core.path),
"group1",
"machine1",
]
],
)
with pytest.raises(ClanError): # user does not exist yet
cli.run(
@@ -377,7 +385,7 @@ def test_groups(
str(test_flake_with_core.path),
"groupb1",
"user1",
]
],
)
cli.run(
[
@@ -388,7 +396,7 @@ def test_groups(
str(test_flake_with_core.path),
"machine1",
machine1_age_key.pubkey,
]
],
)
cli.run(
[
@@ -399,7 +407,7 @@ def test_groups(
str(test_flake_with_core.path),
"group1",
"machine1",
]
],
)
# Should this fail?
@@ -412,7 +420,7 @@ def test_groups(
str(test_flake_with_core.path),
"group1",
"machine1",
]
],
)
cli.run(
@@ -424,7 +432,7 @@ def test_groups(
str(test_flake_with_core.path),
"user1",
user1_age_key.pubkey,
]
],
)
cli.run(
[
@@ -435,7 +443,7 @@ def test_groups(
str(test_flake_with_core.path),
"admin",
admin_age_key.pubkey,
]
],
)
cli.run(
[
@@ -446,12 +454,12 @@ def test_groups(
str(test_flake_with_core.path),
"group1",
"user1",
]
],
)
with capture_output as output:
cli.run(
["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)]
["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)],
)
out = output.out
assert "user1" in out
@@ -472,7 +480,7 @@ def test_groups(
"--group",
"group1",
secret_name,
]
],
)
assert_secrets_file_recipients(
@@ -498,7 +506,7 @@ def test_groups(
str(test_flake_with_core.path),
"group1",
"user1",
]
],
)
assert_secrets_file_recipients(
test_flake_with_core.path,
@@ -520,7 +528,7 @@ def test_groups(
str(test_flake_with_core.path),
"group1",
"user1",
]
],
)
assert_secrets_file_recipients(
test_flake_with_core.path,
@@ -541,7 +549,7 @@ def test_groups(
"--flake",
str(test_flake_with_core.path),
"user1",
]
],
)
assert_secrets_file_recipients(
test_flake_with_core.path,
@@ -562,7 +570,7 @@ def test_groups(
str(test_flake_with_core.path),
"group1",
"machine1",
]
],
)
assert_secrets_file_recipients(
test_flake_with_core.path,
@@ -629,13 +637,15 @@ def test_secrets(
# Generate a new key for the clan
monkeypatch.setenv(
"SOPS_AGE_KEY_FILE", str(test_flake_with_core.path / ".." / "age.key")
"SOPS_AGE_KEY_FILE",
str(test_flake_with_core.path / ".." / "age.key"),
)
with patch(
"clan_cli.secrets.key.generate_private_key", wraps=generate_private_key
"clan_cli.secrets.key.generate_private_key",
wraps=generate_private_key,
) as spy:
cli.run(
["secrets", "key", "generate", "--flake", str(test_flake_with_core.path)]
["secrets", "key", "generate", "--flake", str(test_flake_with_core.path)],
)
assert spy.call_count == 1
@@ -655,18 +665,24 @@ def test_secrets(
str(test_flake_with_core.path),
"testuser",
key["publickey"],
]
],
)
with pytest.raises(ClanError): # does not exist yet
cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "nonexisting"]
[
"secrets",
"get",
"--flake",
str(test_flake_with_core.path),
"nonexisting",
],
)
monkeypatch.setenv("SOPS_NIX_SECRET", "foo")
cli.run(["secrets", "set", "--flake", str(test_flake_with_core.path), "initialkey"])
with capture_output as output:
cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "initialkey"]
["secrets", "get", "--flake", str(test_flake_with_core.path), "initialkey"],
)
assert output.out == "foo"
with capture_output as output:
@@ -684,7 +700,7 @@ def test_secrets(
"--flake",
str(test_flake_with_core.path),
"initialkey",
]
],
)
monkeypatch.delenv("EDITOR")
@@ -696,7 +712,7 @@ def test_secrets(
str(test_flake_with_core.path),
"initialkey",
"key",
]
],
)
with capture_output as output:
@@ -711,7 +727,7 @@ def test_secrets(
"--flake",
str(test_flake_with_core.path),
"nonexisting",
]
],
)
assert output.out == ""
@@ -730,7 +746,7 @@ def test_secrets(
str(test_flake_with_core.path),
"machine1",
age_keys[1].pubkey,
]
],
)
cli.run(
[
@@ -741,18 +757,18 @@ def test_secrets(
str(test_flake_with_core.path),
"machine1",
"key",
]
],
)
with capture_output as output:
cli.run(
["secrets", "machines", "list", "--flake", str(test_flake_with_core.path)]
["secrets", "machines", "list", "--flake", str(test_flake_with_core.path)],
)
assert output.out == "machine1\n"
with use_age_key(age_keys[1].privkey, monkeypatch):
with capture_output as output:
cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"]
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
)
assert output.out == "foo"
@@ -767,14 +783,14 @@ def test_secrets(
"-f",
"machine1",
age_keys[0].privkey,
]
],
)
# should also rotate the encrypted secret
with use_age_key(age_keys[0].privkey, monkeypatch):
with capture_output as output:
cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"]
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
)
assert output.out == "foo"
@@ -787,7 +803,7 @@ def test_secrets(
str(test_flake_with_core.path),
"machine1",
"key",
]
],
)
cli.run(
@@ -799,7 +815,7 @@ def test_secrets(
str(test_flake_with_core.path),
"user1",
age_keys[1].pubkey,
]
],
)
cli.run(
[
@@ -810,7 +826,7 @@ def test_secrets(
str(test_flake_with_core.path),
"user1",
"key",
]
],
)
with capture_output as output, use_age_key(age_keys[1].privkey, monkeypatch):
cli.run(["secrets", "get", "--flake", str(test_flake_with_core.path), "key"])
@@ -824,7 +840,7 @@ def test_secrets(
str(test_flake_with_core.path),
"user1",
"key",
]
],
)
with pytest.raises(ClanError): # does not exist yet
@@ -837,7 +853,7 @@ def test_secrets(
str(test_flake_with_core.path),
"admin-group",
"key",
]
],
)
cli.run(
[
@@ -848,7 +864,7 @@ def test_secrets(
str(test_flake_with_core.path),
"admin-group",
"user1",
]
],
)
cli.run(
[
@@ -859,7 +875,7 @@ def test_secrets(
str(test_flake_with_core.path),
"admin-group",
owner,
]
],
)
cli.run(
[
@@ -870,7 +886,7 @@ def test_secrets(
str(test_flake_with_core.path),
"admin-group",
"key",
]
],
)
cli.run(
@@ -882,13 +898,13 @@ def test_secrets(
"--group",
"admin-group",
"key2",
]
],
)
with use_age_key(age_keys[1].privkey, monkeypatch):
with capture_output as output:
cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"]
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
)
assert output.out == "foo"
@@ -903,7 +919,7 @@ def test_secrets(
"--pgp-key",
gpg_key.fingerprint,
"user2",
]
],
)
# Extend group will update secrets
@@ -916,13 +932,13 @@ def test_secrets(
str(test_flake_with_core.path),
"admin-group",
"user2",
]
],
)
with use_gpg_key(gpg_key, monkeypatch): # user2
with capture_output as output:
cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"]
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
)
assert output.out == "foo"
@@ -935,7 +951,7 @@ def test_secrets(
str(test_flake_with_core.path),
"admin-group",
"user2",
]
],
)
with (
pytest.raises(ClanError),
@@ -955,7 +971,7 @@ def test_secrets(
str(test_flake_with_core.path),
"admin-group",
"key",
]
],
)
cli.run(["secrets", "remove", "--flake", str(test_flake_with_core.path), "key"])
@@ -979,7 +995,8 @@ def test_secrets_key_generate_gpg(
with (
capture_output as output,
patch(
"clan_cli.secrets.key.generate_private_key", wraps=generate_private_key
"clan_cli.secrets.key.generate_private_key",
wraps=generate_private_key,
) as spy_sops,
):
cli.run(
@@ -989,7 +1006,7 @@ def test_secrets_key_generate_gpg(
"generate",
"--flake",
str(test_flake_with_core.path),
]
],
)
assert spy_sops.call_count == 0
# assert "age private key" not in output.out
@@ -1000,7 +1017,7 @@ def test_secrets_key_generate_gpg(
with capture_output as output:
cli.run(
["secrets", "key", "show", "--flake", str(test_flake_with_core.path)]
["secrets", "key", "show", "--flake", str(test_flake_with_core.path)],
)
key = json.loads(output.out)[0]
assert key["type"] == "pgp"
@@ -1017,7 +1034,7 @@ def test_secrets_key_generate_gpg(
"--pgp-key",
gpg_key.fingerprint,
"testuser",
]
],
)
with capture_output as output:
@@ -1029,7 +1046,7 @@ def test_secrets_key_generate_gpg(
"--flake",
str(test_flake_with_core.path),
"testuser",
]
],
)
keys = json.loads(output.out)
assert len(keys) == 1
@@ -1048,7 +1065,7 @@ def test_secrets_key_generate_gpg(
"--flake",
str(test_flake_with_core.path),
"secret-name",
]
],
)
with capture_output as output:
cli.run(
@@ -1058,7 +1075,7 @@ def test_secrets_key_generate_gpg(
"--flake",
str(test_flake_with_core.path),
"secret-name",
]
],
)
assert output.out == "secret-value"
@@ -1078,7 +1095,7 @@ def test_secrets_users_add_age_plugin_error(
str(test_flake_with_core.path),
"testuser",
"AGE-PLUGIN-YUBIKEY-18P5XCQVZ5FE4WKCW3NJWP",
]
],
)
error_msg = str(exc_info.value)

View File

@@ -31,7 +31,7 @@ def test_generate_secret(
str(test_flake_with_core.path),
"user1",
age_keys[0].pubkey,
]
],
)
cli.run(
[
@@ -42,7 +42,7 @@ def test_generate_secret(
str(test_flake_with_core.path),
"admins",
"user1",
]
],
)
cmd = [
"vars",
@@ -56,7 +56,7 @@ def test_generate_secret(
cli.run(cmd)
store1 = SecretStore(
Machine(name="vm1", flake=Flake(str(test_flake_with_core.path)))
Machine(name="vm1", flake=Flake(str(test_flake_with_core.path))),
)
assert store1.exists("", "age.key")
@@ -97,13 +97,13 @@ def test_generate_secret(
str(test_flake_with_core.path),
"--generator",
"zerotier",
]
],
)
assert age_key.lstat().st_mtime_ns == age_key_mtime
assert identity_secret.lstat().st_mtime_ns == secret1_mtime
store2 = SecretStore(
Machine(name="vm2", flake=Flake(str(test_flake_with_core.path)))
Machine(name="vm2", flake=Flake(str(test_flake_with_core.path))),
)
assert store2.exists("", "age.key")

View File

@@ -28,7 +28,10 @@ def test_run_environment(runtime: AsyncRuntime) -> None:
def test_run_local(runtime: AsyncRuntime) -> None:
p1 = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
None,
host.run_local,
["echo", "hello"],
RunOpts(log=Log.STDERR),
)
assert p1.wait().result.stdout == "hello\n"

View File

@@ -189,8 +189,8 @@ def test_generate_public_and_secret_vars(
nix_eval(
[
f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.value_with_default.value",
]
)
],
),
).stdout.strip()
assert json.loads(value_non_default) == "default_value"
@@ -211,14 +211,17 @@ def test_generate_public_and_secret_vars(
public_value = get_machine_var(machine, "my_generator/my_value").printable_value
assert public_value.startswith("public")
shared_value = get_machine_var(
machine, "my_shared_generator/my_shared_value"
machine,
"my_shared_generator/my_shared_value",
).printable_value
assert shared_value.startswith("shared")
vars_text = stringify_all_vars(machine)
flake_obj = Flake(str(flake.path))
my_generator = Generator("my_generator", machine="my_machine", _flake=flake_obj)
dependent_generator = Generator(
"dependent_generator", machine="my_machine", _flake=flake_obj
"dependent_generator",
machine="my_machine",
_flake=flake_obj,
)
in_repo_store = in_repo.FactStore(flake=flake_obj)
assert not in_repo_store.exists(my_generator, "my_secret")
@@ -235,8 +238,8 @@ def test_generate_public_and_secret_vars(
nix_eval(
[
f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.my_value.value",
]
)
],
),
).stdout.strip()
assert json.loads(vars_eval).startswith("public")
@@ -244,14 +247,14 @@ def test_generate_public_and_secret_vars(
nix_eval(
[
f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.value_with_default.value",
]
)
],
),
).stdout.strip()
assert json.loads(value_non_default).startswith("non-default")
# test regeneration works
cli.run(
["vars", "generate", "--flake", str(flake.path), "my_machine", "--regenerate"]
["vars", "generate", "--flake", str(flake.path), "my_machine", "--regenerate"],
)
# test regeneration without sandbox
cli.run(
@@ -263,7 +266,7 @@ def test_generate_public_and_secret_vars(
"my_machine",
"--regenerate",
"--no-sandbox",
]
],
)
# test stuff actually changed after regeneration
public_value_new = get_machine_var(machine, "my_generator/my_value").printable_value
@@ -273,7 +276,8 @@ def test_generate_public_and_secret_vars(
"Secret value should change after regeneration"
)
shared_value_new = get_machine_var(
machine, "my_shared_generator/my_shared_value"
machine,
"my_shared_generator/my_shared_value",
).printable_value
assert shared_value != shared_value_new, (
"Shared value should change after regeneration"
@@ -290,18 +294,20 @@ def test_generate_public_and_secret_vars(
"--no-sandbox",
"--generator",
"my_shared_generator",
]
],
)
# test that the shared generator is regenerated
shared_value_after_regeneration = get_machine_var(
machine, "my_shared_generator/my_shared_value"
machine,
"my_shared_generator/my_shared_value",
).printable_value
assert shared_value_after_regeneration != shared_value_new, (
"Shared value should change after regenerating my_shared_generator"
)
# test that the dependent generator is also regenerated (because it depends on my_shared_generator)
secret_value_after_regeneration = sops_store.get(
dependent_generator, "my_secret"
dependent_generator,
"my_secret",
).decode()
assert secret_value_after_regeneration != secret_value_new, (
"Dependent generator's secret should change after regenerating my_shared_generator"
@@ -311,7 +317,8 @@ def test_generate_public_and_secret_vars(
)
# test that my_generator is NOT regenerated (it doesn't depend on my_shared_generator)
public_value_after_regeneration = get_machine_var(
machine, "my_generator/my_value"
machine,
"my_generator/my_value",
).printable_value
assert public_value_after_regeneration == public_value_new, (
"my_generator value should NOT change after regenerating only my_shared_generator"
@@ -348,10 +355,14 @@ def test_generate_secret_var_sops_with_default_group(
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
flake_obj = Flake(str(flake.path))
first_generator = Generator(
"first_generator", machine="my_machine", _flake=flake_obj
"first_generator",
machine="my_machine",
_flake=flake_obj,
)
second_generator = Generator(
"second_generator", machine="my_machine", _flake=flake_obj
"second_generator",
machine="my_machine",
_flake=flake_obj,
)
in_repo_store = in_repo.FactStore(flake=flake_obj)
assert not in_repo_store.exists(first_generator, "my_secret")
@@ -372,16 +383,22 @@ def test_generate_secret_var_sops_with_default_group(
str(flake.path),
"user2",
pubkey_user2.pubkey,
]
],
)
cli.run(["secrets", "groups", "add-user", "my_group", "user2"])
# check if new user can access the secret
monkeypatch.setenv("USER", "user2")
first_generator_with_share = Generator(
"first_generator", share=False, machine="my_machine", _flake=flake_obj
"first_generator",
share=False,
machine="my_machine",
_flake=flake_obj,
)
second_generator_with_share = Generator(
"second_generator", share=False, machine="my_machine", _flake=flake_obj
"second_generator",
share=False,
machine="my_machine",
_flake=flake_obj,
)
assert sops_store.user_has_access("user2", first_generator_with_share, "my_secret")
assert sops_store.user_has_access("user2", second_generator_with_share, "my_secret")
@@ -398,7 +415,7 @@ def test_generate_secret_var_sops_with_default_group(
"--force",
"user2",
pubkey_user3.pubkey,
]
],
)
monkeypatch.setenv("USER", "user2")
assert sops_store.user_has_access("user2", first_generator_with_share, "my_secret")
@@ -438,10 +455,16 @@ def test_generated_shared_secret_sops(
m2_sops_store = sops.SecretStore(machine2.flake)
# Create generators with machine context for testing
generator_m1 = Generator(
"my_shared_generator", share=True, machine="machine1", _flake=machine1.flake
"my_shared_generator",
share=True,
machine="machine1",
_flake=machine1.flake,
)
generator_m2 = Generator(
"my_shared_generator", share=True, machine="machine2", _flake=machine2.flake
"my_shared_generator",
share=True,
machine="machine2",
_flake=machine2.flake,
)
assert m1_sops_store.exists(generator_m1, "my_shared_secret")
@@ -492,7 +515,9 @@ def test_generate_secret_var_password_store(
check=True,
)
subprocess.run(
["git", "config", "user.name", "Test User"], cwd=password_store_dir, check=True
["git", "config", "user.name", "Test User"],
cwd=password_store_dir,
check=True,
)
flake_obj = Flake(str(flake.path))
@@ -502,10 +527,18 @@ def test_generate_secret_var_password_store(
assert check_vars(machine.name, machine.flake)
store = password_store.SecretStore(flake=flake_obj)
my_generator = Generator(
"my_generator", share=False, files=[], machine="my_machine", _flake=flake_obj
"my_generator",
share=False,
files=[],
machine="my_machine",
_flake=flake_obj,
)
my_generator_shared = Generator(
"my_generator", share=True, files=[], machine="my_machine", _flake=flake_obj
"my_generator",
share=True,
files=[],
machine="my_machine",
_flake=flake_obj,
)
my_shared_generator = Generator(
"my_shared_generator",
@@ -538,7 +571,11 @@ def test_generate_secret_var_password_store(
assert "my_generator/my_secret" in vars_text
my_generator = Generator(
"my_generator", share=False, files=[], machine="my_machine", _flake=flake_obj
"my_generator",
share=False,
files=[],
machine="my_machine",
_flake=flake_obj,
)
var_name = "my_secret"
store.delete(my_generator, var_name)
@@ -547,7 +584,11 @@ def test_generate_secret_var_password_store(
store.delete_store("my_machine")
store.delete_store("my_machine") # check idempotency
my_generator2 = Generator(
"my_generator2", share=False, files=[], machine="my_machine", _flake=flake_obj
"my_generator2",
share=False,
files=[],
machine="my_machine",
_flake=flake_obj,
)
var_name = "my_secret2"
assert not store.exists(my_generator2, var_name)
@@ -686,9 +727,7 @@ def test_shared_vars_must_never_depend_on_machine_specific_vars(
monkeypatch: pytest.MonkeyPatch,
flake_with_sops: ClanFlake,
) -> None:
"""
Ensure that shared vars never depend on machine specific vars.
"""
"""Ensure that shared vars never depend on machine specific vars."""
flake = flake_with_sops
config = flake.machines["my_machine"]
@@ -719,8 +758,7 @@ def test_multi_machine_shared_vars(
monkeypatch: pytest.MonkeyPatch,
flake_with_sops: ClanFlake,
) -> None:
"""
Ensure that shared vars are regenerated only when they should, and also can be
"""Ensure that shared vars are regenerated only when they should, and also can be
accessed by all machines that should have access.
Specifically:
@@ -752,10 +790,16 @@ def test_multi_machine_shared_vars(
in_repo_store_2 = in_repo.FactStore(machine2.flake)
# Create generators with machine context for testing
generator_m1 = Generator(
"shared_generator", share=True, machine="machine1", _flake=machine1.flake
"shared_generator",
share=True,
machine="machine1",
_flake=machine1.flake,
)
generator_m2 = Generator(
"shared_generator", share=True, machine="machine2", _flake=machine2.flake
"shared_generator",
share=True,
machine="machine2",
_flake=machine2.flake,
)
# generate for machine 1
cli.run(["vars", "generate", "--flake", str(flake.path), "machine1"])
@@ -771,7 +815,7 @@ def test_multi_machine_shared_vars(
# ensure shared secret stays available for all machines after regeneration
# regenerate for machine 1
cli.run(
["vars", "generate", "--flake", str(flake.path), "machine1", "--regenerate"]
["vars", "generate", "--flake", str(flake.path), "machine1", "--regenerate"],
)
# ensure values changed
new_secret_1 = sops_store_1.get(generator_m1, "my_secret")
@@ -806,7 +850,7 @@ def test_api_set_prompts(
prompt_values={
"my_generator": {
"prompt1": "input1",
}
},
},
)
machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
@@ -820,14 +864,16 @@ def test_api_set_prompts(
prompt_values={
"my_generator": {
"prompt1": "input2",
}
},
},
)
assert store.get(my_generator, "prompt1").decode() == "input2"
machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
generators = get_generators(
machine=machine, full_closure=True, include_previous_values=True
machine=machine,
full_closure=True,
include_previous_values=True,
)
# get_generators should bind the store
assert generators[0].files[0]._store is not None
@@ -957,7 +1003,9 @@ def test_migration(
flake_obj = Flake(str(flake.path))
my_generator = Generator("my_generator", machine="my_machine", _flake=flake_obj)
other_generator = Generator(
"other_generator", machine="my_machine", _flake=flake_obj
"other_generator",
machine="my_machine",
_flake=flake_obj,
)
in_repo_store = in_repo.FactStore(flake=flake_obj)
sops_store = sops.SecretStore(flake=flake_obj)
@@ -1023,7 +1071,8 @@ def test_fails_when_files_are_left_from_other_backend(
@pytest.mark.with_core
def test_create_sops_age_secrets(
monkeypatch: pytest.MonkeyPatch, flake: ClanFlake
monkeypatch: pytest.MonkeyPatch,
flake: ClanFlake,
) -> None:
monkeypatch.chdir(flake.path)
cli.run(["vars", "keygen", "--flake", str(flake.path), "--user", "user"])
@@ -1111,7 +1160,7 @@ def test_dynamic_invalidation(
in {
clan.core.vars.generators.dependent_generator.validation = if builtins.pathExists p then builtins.readFile p else null;
}
"""
""",
)
flake.refresh()

View File

@@ -29,30 +29,30 @@ def test_vm_deployment(
nix_eval(
[
f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.sops.secrets",
]
)
).stdout.strip()
],
),
).stdout.strip(),
)
assert sops_secrets != {}
my_secret_path = run(
nix_eval(
[
f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.clan.core.vars.generators.m1_generator.files.my_secret.path",
]
)
],
),
).stdout.strip()
assert "no-such-path" not in my_secret_path
shared_secret_path = run(
nix_eval(
[
f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.clan.core.vars.generators.my_shared_generator.files.shared_secret.path",
]
)
],
),
).stdout.strip()
assert "no-such-path" not in shared_secret_path
vm1_config = inspect_vm(
machine=Machine("test-vm-deployment", Flake(str(vm_test_flake)))
machine=Machine("test-vm-deployment", Flake(str(vm_test_flake))),
)
with ExitStack() as stack:
vm1 = stack.enter_context(spawn_vm(vm1_config, stdin=subprocess.DEVNULL))
@@ -64,7 +64,7 @@ def test_vm_deployment(
assert result.stdout == "hello\n"
# check shared_secret is deployed
result = qga_m1.run(
["cat", "/run/secrets/vars/my_shared_generator/shared_secret"]
["cat", "/run/secrets/vars/my_shared_generator/shared_secret"],
)
assert result.stdout == "hello\n"
# check no_deploy_secret is not deployed

View File

@@ -17,7 +17,8 @@ no_kvm = not Path("/dev/kvm").exists()
@pytest.mark.with_core
def test_inspect(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput
test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None:
with capture_output as output:
cli.run(["vms", "inspect", "--flake", str(test_flake_with_core.path), "vm1"])
@@ -42,7 +43,7 @@ def test_run(
"add",
"user1",
age_keys[0].pubkey,
]
],
)
cli.run(
[
@@ -51,7 +52,7 @@ def test_run(
"add-user",
"admins",
"user1",
]
],
)
cli.run(
[
@@ -63,7 +64,7 @@ def test_run(
"shutdown",
"-h",
"now",
]
],
)
@@ -74,7 +75,7 @@ def test_vm_persistence(
) -> None:
# Use the pre-built test VM from the test flake
vm_config = inspect_vm(
machine=Machine("test-vm-persistence", Flake(str(vm_test_flake)))
machine=Machine("test-vm-persistence", Flake(str(vm_test_flake))),
)
with spawn_vm(vm_config) as vm, vm.qga_connect() as qga:

View File

@@ -62,9 +62,7 @@ class StoreBase(ABC):
var: "Var",
value: bytes,
) -> Path | None:
"""
override this method to implement the actual creation of the file
"""
"""Override this method to implement the actual creation of the file"""
@abstractmethod
def exists(self, generator: "Generator", name: str) -> bool:
@@ -81,8 +79,7 @@ class StoreBase(ABC):
generators: list["Generator"] | None = None,
file_name: str | None = None,
) -> str | None:
"""
Check the health of the store for the given machine and generators.
"""Check the health of the store for the given machine and generators.
This method detects any issues or inconsistencies in the store that may
require fixing (e.g., outdated encryption keys, missing permissions).
@@ -94,6 +91,7 @@ class StoreBase(ABC):
Returns:
str | None: An error message describing issues found, or None if everything is healthy
"""
return None
@@ -103,8 +101,7 @@ class StoreBase(ABC):
generators: list["Generator"] | None = None,
file_name: str | None = None,
) -> None:
"""
Fix any issues with the store for the given machine and generators.
"""Fix any issues with the store for the given machine and generators.
This method is intended to repair or update the store when inconsistencies
are detected (e.g., re-encrypting secrets with new keys, fixing permissions).
@@ -116,6 +113,7 @@ class StoreBase(ABC):
Returns:
None
"""
return
@@ -164,16 +162,15 @@ class StoreBase(ABC):
log_info = machine.info
if self.is_secret_store:
log.info(f"{action_str} secret var {generator.name}/{var.name}\n")
elif value != old_val:
msg = f"{action_str} var {generator.name}/{var.name}"
if not is_migration:
msg += f"\n old: {old_val_str}\n new: {string_repr(value)}"
log_info(msg)
else:
if value != old_val:
msg = f"{action_str} var {generator.name}/{var.name}"
if not is_migration:
msg += f"\n old: {old_val_str}\n new: {string_repr(value)}"
log_info(msg)
else:
log_info(
f"Var {generator.name}/{var.name} remains unchanged: {old_val_str}"
)
log_info(
f"Var {generator.name}/{var.name} remains unchanged: {old_val_str}",
)
return new_file
@abstractmethod
@@ -200,8 +197,7 @@ class StoreBase(ABC):
"""
def get_validation(self, generator: "Generator") -> str | None:
"""
Return the invalidation hash that indicates if a generator needs to be re-run
"""Return the invalidation hash that indicates if a generator needs to be re-run
due to a change in its definition
"""
hash_file = self.directory(generator, ".validation-hash")
@@ -210,17 +206,14 @@ class StoreBase(ABC):
return hash_file.read_text().strip()
def set_validation(self, generator: "Generator", hash_str: str) -> Path:
"""
Store the invalidation hash that indicates if a generator needs to be re-run
"""
"""Store the invalidation hash that indicates if a generator needs to be re-run"""
hash_file = self.directory(generator, ".validation-hash")
hash_file.parent.mkdir(parents=True, exist_ok=True)
hash_file.write_text(hash_str)
return hash_file
def hash_is_valid(self, generator: "Generator") -> bool:
"""
Check if the invalidation hash is up to date
"""Check if the invalidation hash is up to date
If the hash is not set in nix and hasn't been stored before, it is considered valid
-> this provides backward and forward compatibility
"""

View File

@@ -28,7 +28,9 @@ class VarStatus:
def vars_status(
machine_name: str, flake: Flake, generator_name: None | str = None
machine_name: str,
flake: Flake,
generator_name: None | str = None,
) -> VarStatus:
machine = Machine(name=machine_name, flake=flake)
missing_secret_vars = []
@@ -53,14 +55,14 @@ def vars_status(
for generator in generators:
for file in generator.files:
file.store(
machine.secret_vars_store if file.secret else machine.public_vars_store
machine.secret_vars_store if file.secret else machine.public_vars_store,
)
file.generator(generator)
if file.secret:
if not machine.secret_vars_store.exists(generator, file.name):
machine.info(
f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing."
f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing.",
)
missing_secret_vars.append(file)
else:
@@ -71,13 +73,13 @@ def vars_status(
)
if msg:
machine.info(
f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} needs update: {msg}"
f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} needs update: {msg}",
)
unfixed_secret_vars.append(file)
elif not machine.public_vars_store.exists(generator, file.name):
machine.info(
f"Public var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing."
f"Public var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing.",
)
missing_public_vars.append(file)
# check if invalidation hash is up to date
@@ -87,7 +89,7 @@ def vars_status(
):
invalid_generators.append(generator.name)
machine.info(
f"Generator '{generator.name}' in machine {machine.name} has outdated invalidation hash."
f"Generator '{generator.name}' in machine {machine.name} has outdated invalidation hash.",
)
return VarStatus(
missing_secret_vars,
@@ -98,7 +100,9 @@ def vars_status(
def check_vars(
machine_name: str, flake: Flake, generator_name: None | str = None
machine_name: str,
flake: Flake,
generator_name: None | str = None,
) -> bool:
status = vars_status(machine_name, flake, generator_name=generator_name)
return not (

View File

@@ -6,7 +6,8 @@ from clan_lib.errors import ClanError
def test_check_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -21,7 +21,7 @@ def generate_command(args: argparse.Namespace) -> None:
filter(
lambda m: m.name in args.machines,
machines,
)
),
)
# prefetch all vars
@@ -32,7 +32,7 @@ def generate_command(args: argparse.Namespace) -> None:
flake.precache(
[
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.generators.*.validationHash",
]
],
)
run_generators(

View File

@@ -6,7 +6,8 @@ from clan_lib.errors import ClanError
def test_generate_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -19,7 +19,8 @@ log = logging.getLogger(__name__)
def dependencies_as_dir(
decrypted_dependencies: dict[str, dict[str, bytes]], tmpdir: Path
decrypted_dependencies: dict[str, dict[str, bytes]],
tmpdir: Path,
) -> None:
"""Helper function to create directory structure from decrypted dependencies."""
for dep_generator, files in decrypted_dependencies.items():
@@ -72,13 +73,15 @@ class Generator:
flake: "Flake",
include_previous_values: bool = False,
) -> list["Generator"]:
"""
Get all generators for a machine from the flake.
"""Get all generators for a machine from the flake.
Args:
machine_name (str): The name of the machine.
flake (Flake): The flake to get the generators from.
Returns:
list[Generator]: A list of (unsorted) generators for the machine.
"""
# Get all generator metadata in one select (safe fields only)
generators_data = flake.select_machine(
@@ -146,7 +149,8 @@ class Generator:
for generator in generators:
for prompt in generator.prompts:
prompt.previous_value = generator.get_previous_value(
machine, prompt
machine,
prompt,
)
return generators
@@ -175,8 +179,8 @@ class Generator:
machine = Machine(name=self.machine, flake=self._flake)
output = Path(
machine.select(
f'config.clan.core.vars.generators."{self.name}".finalScript'
)
f'config.clan.core.vars.generators."{self.name}".finalScript',
),
)
if tmp_store := nix_test_store():
output = tmp_store.joinpath(*output.parts[1:])
@@ -189,7 +193,7 @@ class Generator:
machine = Machine(name=self.machine, flake=self._flake)
return machine.select(
f'config.clan.core.vars.generators."{self.name}".validationHash'
f'config.clan.core.vars.generators."{self.name}".validationHash',
)
def decrypt_dependencies(
@@ -207,6 +211,7 @@ class Generator:
Returns:
Dictionary mapping generator names to their variable values
"""
from clan_lib.errors import ClanError
@@ -222,7 +227,8 @@ class Generator:
result[dep_key.name] = {}
dep_generator = next(
(g for g in generators if g.name == dep_key.name), None
(g for g in generators if g.name == dep_key.name),
None,
)
if dep_generator is None:
msg = f"Generator {dep_key.name} not found in machine {machine.name}"
@@ -237,11 +243,13 @@ class Generator:
for file in dep_files:
if file.secret:
result[dep_key.name][file.name] = secret_vars_store.get(
dep_generator, file.name
dep_generator,
file.name,
)
else:
result[dep_key.name][file.name] = public_vars_store.get(
dep_generator, file.name
dep_generator,
file.name,
)
return result
@@ -250,6 +258,7 @@ class Generator:
Returns:
Dictionary mapping prompt names to their values
"""
from .prompt import ask
@@ -275,6 +284,7 @@ class Generator:
machine: The machine to execute the generator for
prompt_values: Optional dictionary of prompt values. If not provided, prompts will be asked interactively.
no_sandbox: Whether to disable sandboxing when executing the generator
"""
import os
import sys
@@ -333,8 +343,8 @@ class Generator:
"--uid", "1000",
"--gid", "1000",
"--",
str(real_bash_path), "-c", generator
]
str(real_bash_path), "-c", generator,
],
)
# fmt: on
@@ -418,11 +428,11 @@ class Generator:
if validation is not None:
if public_changed:
files_to_commit.append(
machine.public_vars_store.set_validation(self, validation)
machine.public_vars_store.set_validation(self, validation),
)
if secret_changed:
files_to_commit.append(
machine.secret_vars_store.set_validation(self, validation)
machine.secret_vars_store.set_validation(self, validation),
)
commit_files(

View File

@@ -33,7 +33,7 @@ def get_machine_var(machine: Machine, var_id: str) -> Var:
raise ClanError(msg)
if len(results) > 1:
error = f"Found multiple vars for {var_id}:\n - " + "\n - ".join(
[str(var) for var in results]
[str(var) for var in results],
)
raise ClanError(error)
# we have exactly one result at this point

View File

@@ -72,7 +72,8 @@ def add_dependents(
def toposort_closure(
_closure: Iterable[GeneratorKey], generators: dict[GeneratorKey, Generator]
_closure: Iterable[GeneratorKey],
generators: dict[GeneratorKey, Generator],
) -> list[Generator]:
closure = set(_closure)
# return the topological sorted list of generators to execute
@@ -87,8 +88,7 @@ def toposort_closure(
# all generators in topological order
def full_closure(generators: dict[GeneratorKey, Generator]) -> list[Generator]:
"""
From a set of generators, return all generators in topological order.
"""From a set of generators, return all generators in topological order.
This includes all dependencies and dependents of the generators.
Returns all generators in topological order.
"""
@@ -97,8 +97,7 @@ def full_closure(generators: dict[GeneratorKey, Generator]) -> list[Generator]:
# just the missing generators including their dependents
def all_missing_closure(generators: dict[GeneratorKey, Generator]) -> list[Generator]:
"""
From a set of generators, return all incomplete generators in topological order.
"""From a set of generators, return all incomplete generators in topological order.
incomplete
: A generator is missing if at least one of its files is missing.
@@ -111,7 +110,8 @@ def all_missing_closure(generators: dict[GeneratorKey, Generator]) -> list[Gener
# only a selected list of generators including their missing dependencies and their dependents
def requested_closure(
requested_generators: list[GeneratorKey], generators: dict[GeneratorKey, Generator]
requested_generators: list[GeneratorKey],
generators: dict[GeneratorKey, Generator],
) -> list[Generator]:
closure = set(requested_generators)
# extend the graph to include all dependencies which are not on disk
@@ -123,7 +123,8 @@ def requested_closure(
# just enough to ensure that the list of selected generators are in a consistent state.
# empty if nothing is missing.
def minimal_closure(
requested_generators: list[GeneratorKey], generators: dict[GeneratorKey, Generator]
requested_generators: list[GeneratorKey],
generators: dict[GeneratorKey, Generator],
) -> list[Generator]:
closure = set(requested_generators)
final_closure = missing_dependency_closure(closure, generators)

View File

@@ -27,11 +27,11 @@ def _get_user_or_default(user: str | None) -> str:
# TODO: Unify with "create clan" should be done automatically
@API.register
def create_secrets_user(
flake_dir: Path, user: str | None = None, force: bool = False
flake_dir: Path,
user: str | None = None,
force: bool = False,
) -> None:
"""
initialize sops keys for vars
"""
"""Initialize sops keys for vars"""
user = _get_user_or_default(user)
pub_keys = maybe_get_admin_public_keys()
if not pub_keys:
@@ -51,11 +51,11 @@ def _select_keys_interactive(pub_keys: list[SopsKey]) -> list[SopsKey]:
selected_keys: list[SopsKey] = []
for i, key in enumerate(pub_keys):
log.info(
f"{i + 1}: type: {key.key_type}\n pubkey: {key.pubkey}\n source: {key.source}"
f"{i + 1}: type: {key.key_type}\n pubkey: {key.pubkey}\n source: {key.source}",
)
while not selected_keys:
choice = input(
"Select keys to use (comma-separated list of numbers, or leave empty to select all): "
"Select keys to use (comma-separated list of numbers, or leave empty to select all): ",
).strip()
if not choice:
log.info("No keys selected, using all keys.")
@@ -71,11 +71,11 @@ def _select_keys_interactive(pub_keys: list[SopsKey]) -> list[SopsKey]:
def create_secrets_user_interactive(
flake_dir: Path, user: str | None = None, force: bool = False
flake_dir: Path,
user: str | None = None,
force: bool = False,
) -> None:
"""
Initialize sops keys for vars interactively.
"""
"""Initialize sops keys for vars interactively."""
user = _get_user_or_default(user)
pub_keys = maybe_get_admin_public_keys()
if pub_keys:
@@ -83,13 +83,13 @@ def create_secrets_user_interactive(
pub_keys = _select_keys_interactive(pub_keys)
else:
log.info(
"\nNo admin keys found on this machine, generating a new key for sops."
"\nNo admin keys found on this machine, generating a new key for sops.",
)
pub_keys = [generate_key()]
# make sure the user backups the generated key
log.info("\n⚠️ IMPORTANT: Secret Key Backup ⚠️")
log.info(
"The generated key above is CRITICAL for accessing your clan's secrets."
"The generated key above is CRITICAL for accessing your clan's secrets.",
)
log.info("Without this key, you will lose access to all encrypted data!")
log.info("Please backup the key file immediately to a secure location.")
@@ -97,12 +97,12 @@ def create_secrets_user_interactive(
confirm = None
while not confirm or confirm.lower() != "y":
log.info(
"\nI have backed up the key file to a secure location. Confirm [y/N]: "
"\nI have backed up the key file to a secure location. Confirm [y/N]: ",
)
confirm = input().strip().lower()
if confirm != "y":
log.error(
"You must backup the key before proceeding. This is critical for data recovery!"
"You must backup the key before proceeding. This is critical for data recovery!",
)
# persist the generated or chosen admin pubkey in the repo
@@ -115,11 +115,11 @@ def create_secrets_user_interactive(
def create_secrets_user_auto(
flake_dir: Path, user: str | None = None, force: bool = False
flake_dir: Path,
user: str | None = None,
force: bool = False,
) -> None:
"""
Detect if the user is in interactive mode or not and choose the appropriate routine.
"""
"""Detect if the user is in interactive mode or not and choose the appropriate routine."""
if sys.stdin.isatty():
create_secrets_user_interactive(
flake_dir=flake_dir,
@@ -159,7 +159,10 @@ def register_keygen_parser(parser: argparse.ArgumentParser) -> None:
)
parser.add_argument(
"-f", "--force", help="overwrite existing user", action="store_true"
"-f",
"--force",
help="overwrite existing user",
action="store_true",
)
parser.add_argument(

View File

@@ -29,13 +29,13 @@ def _migration_file_exists(
if machine.secret_facts_store.exists(generator.name, fact_name):
return True
machine.debug(
f"Cannot migrate fact {fact_name} for service {generator.name}, as it does not exist in the secret fact store"
f"Cannot migrate fact {fact_name} for service {generator.name}, as it does not exist in the secret fact store",
)
if not is_secret:
if machine.public_facts_store.exists(generator.name, fact_name):
return True
machine.debug(
f"Cannot migrate fact {fact_name} for service {generator.name}, as it does not exist in the public fact store"
f"Cannot migrate fact {fact_name} for service {generator.name}, as it does not exist in the public fact store",
)
return False
@@ -59,14 +59,20 @@ def _migrate_file(
if file.secret:
old_value = machine.secret_facts_store.get(service_name, fact_name)
maybe_path = machine.secret_vars_store.set(
generator, file, old_value, is_migration=True
generator,
file,
old_value,
is_migration=True,
)
if maybe_path:
paths.append(maybe_path)
else:
old_value = machine.public_facts_store.get(service_name, fact_name)
maybe_path = machine.public_vars_store.set(
generator, file, old_value, is_migration=True
generator,
file,
old_value,
is_migration=True,
)
if maybe_path:
paths.append(maybe_path)
@@ -84,7 +90,11 @@ def migrate_files(
if _migration_file_exists(machine, generator, file.name):
assert generator.migrate_fact is not None
files_to_commit += _migrate_file(
machine, generator, file.name, generator.migrate_fact, file.name
machine,
generator,
file.name,
generator.migrate_fact,
file.name,
)
else:
not_found.append(file.name)
@@ -114,11 +124,10 @@ def check_can_migrate(
all_files_missing = False
else:
all_files_present = False
elif machine.public_vars_store.exists(generator, file.name):
all_files_missing = False
else:
if machine.public_vars_store.exists(generator, file.name):
all_files_missing = False
else:
all_files_present = False
all_files_present = False
if not all_files_present and not all_files_missing:
msg = f"Cannot migrate facts for generator {generator.name} as some files already exist in the store"
@@ -132,5 +141,5 @@ def check_can_migrate(
all(
_migration_file_exists(machine, generator, file.name)
for file in generator.files
)
),
)

View File

@@ -44,8 +44,8 @@ class Prompt:
"group": None,
"helperText": None,
"required": False,
}
)
},
),
)
@classmethod
@@ -60,13 +60,11 @@ class Prompt:
def get_multiline_hidden_input() -> str:
"""
Get multiline input from the user without echoing the input.
"""Get multiline input from the user without echoing the input.
This function allows the user to enter multiple lines of text,
and it will return the concatenated string of all lines entered.
The user can finish the input by pressing Ctrl-D (EOF).
"""
# Save terminal settings
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
@@ -136,7 +134,7 @@ def ask(
result = sys.stdin.read()
case PromptType.MULTILINE_HIDDEN:
print(
"Enter multiple lines (press Ctrl-D to finish or Ctrl-C to cancel):"
"Enter multiple lines (press Ctrl-D to finish or Ctrl-C to cancel):",
)
result = get_multiline_hidden_input()
case PromptType.HIDDEN:

View File

@@ -33,7 +33,11 @@ class SecretStore(StoreBase):
"""Get the password store directory, cached per machine."""
if not self._store_dir:
result = self._run_pass(
machine, "git", "rev-parse", "--show-toplevel", check=False
machine,
"git",
"rev-parse",
"--show-toplevel",
check=False,
)
if result.returncode != 0:
msg = "Password store must be a git repository"
@@ -43,7 +47,8 @@ class SecretStore(StoreBase):
def _pass_command(self, machine: str) -> str:
out_path = self.flake.select_machine(
machine, "config.clan.core.vars.password-store.passPackage.outPath"
machine,
"config.clan.core.vars.password-store.passPackage.outPath",
)
main_program = (
self.flake.select_machine(
@@ -133,13 +138,24 @@ class SecretStore(StoreBase):
result = self._run_pass(machine, "ls", str(machine_dir), check=False)
if result.returncode == 0:
self._run_pass(
machine, "rm", "--force", "--recursive", str(machine_dir), check=True
machine,
"rm",
"--force",
"--recursive",
str(machine_dir),
check=True,
)
return []
def generate_hash(self, machine: str) -> bytes:
result = self._run_pass(
machine, "git", "log", "-1", "--format=%H", self.entry_prefix, check=False
machine,
"git",
"log",
"-1",
"--format=%H",
self.entry_prefix,
check=False,
)
git_hash = result.stdout.strip()
@@ -183,7 +199,8 @@ class SecretStore(StoreBase):
vars_generators = Generator.get_machine_generators(machine, self.flake)
if "users" in phases:
with tarfile.open(
output_dir / "secrets_for_users.tar.gz", "w:gz"
output_dir / "secrets_for_users.tar.gz",
"w:gz",
) as user_tar:
for generator in vars_generators:
dir_exists = False
@@ -255,7 +272,8 @@ class SecretStore(StoreBase):
self.populate_dir(machine, pass_dir, phases)
upload_dir = Path(
self.flake.select_machine(
machine, "config.clan.core.vars.password-store.secretLocation"
)
machine,
"config.clan.core.vars.password-store.secretLocation",
),
)
upload(host, pass_dir, upload_dir)

View File

@@ -75,7 +75,8 @@ class SecretStore(StoreBase):
sops_secrets_folder(self.flake.path) / f"{machine}-age.key",
priv_key,
add_groups=self.flake.select_machine(
machine, "config.clan.core.sops.defaultGroups"
machine,
"config.clan.core.sops.defaultGroups",
),
age_plugins=load_age_plugins(self.flake),
)
@@ -86,7 +87,10 @@ class SecretStore(StoreBase):
return "sops"
def user_has_access(
self, user: str, generator: Generator, secret_name: str
self,
user: str,
generator: Generator,
secret_name: str,
) -> bool:
key_dir = sops_users_folder(self.flake.path) / user
return self.key_has_access(key_dir, generator, secret_name)
@@ -98,7 +102,10 @@ class SecretStore(StoreBase):
return self.key_has_access(key_dir, generator, secret_name)
def key_has_access(
self, key_dir: Path, generator: Generator, secret_name: str
self,
key_dir: Path,
generator: Generator,
secret_name: str,
) -> bool:
secret_path = self.secret_path(generator, secret_name)
recipient = sops.SopsKey.load_dir(key_dir)
@@ -115,8 +122,7 @@ class SecretStore(StoreBase):
generators: list[Generator] | None = None,
file_name: str | None = None,
) -> str | None:
"""
Check if SOPS secrets need to be re-encrypted due to recipient changes.
"""Check if SOPS secrets need to be re-encrypted due to recipient changes.
This method verifies that all secrets are properly encrypted with the current
set of recipient keys. It detects when new users or machines have been added
@@ -132,8 +138,8 @@ class SecretStore(StoreBase):
Raises:
ClanError: If the specified file_name is not found
"""
"""
if generators is None:
from clan_cli.vars.generator import Generator
@@ -185,7 +191,8 @@ class SecretStore(StoreBase):
value,
add_machines=[machine] if var.deploy else [],
add_groups=self.flake.select_machine(
machine, "config.clan.core.sops.defaultGroups"
machine,
"config.clan.core.sops.defaultGroups",
),
git_commit=False,
age_plugins=load_age_plugins(self.flake),
@@ -291,17 +298,18 @@ class SecretStore(StoreBase):
keys = collect_keys_for_path(path)
for group in self.flake.select_machine(
machine, "config.clan.core.sops.defaultGroups"
machine,
"config.clan.core.sops.defaultGroups",
):
keys.update(
collect_keys_for_type(
self.flake.path / "sops" / "groups" / group / "machines"
)
self.flake.path / "sops" / "groups" / group / "machines",
),
)
keys.update(
collect_keys_for_type(
self.flake.path / "sops" / "groups" / group / "users"
)
self.flake.path / "sops" / "groups" / group / "users",
),
)
return keys
@@ -329,8 +337,7 @@ class SecretStore(StoreBase):
generators: list[Generator] | None = None,
file_name: str | None = None,
) -> None:
"""
Fix sops secrets by re-encrypting them with the current set of recipient keys.
"""Fix sops secrets by re-encrypting them with the current set of recipient keys.
This method updates secrets when recipients have changed (e.g., new admin users
were added to the clan). It ensures all authorized recipients have access to the
@@ -343,6 +350,7 @@ class SecretStore(StoreBase):
Raises:
ClanError: If the specified file_name is not found
"""
from clan_cli.secrets.secrets import update_keys
@@ -368,7 +376,8 @@ class SecretStore(StoreBase):
gen_machine = self.get_machine(generator)
for group in self.flake.select_machine(
gen_machine, "config.clan.core.sops.defaultGroups"
gen_machine,
"config.clan.core.sops.defaultGroups",
):
allow_member(
groups_folder(secret_path),

View File

@@ -13,13 +13,17 @@ log = logging.getLogger(__name__)
def upload_secret_vars(machine: Machine, host: Host) -> None:
machine.secret_vars_store.upload(
machine.name, host, phases=["activation", "users", "services"]
machine.name,
host,
phases=["activation", "users", "services"],
)
def populate_secret_vars(machine: Machine, directory: Path) -> None:
machine.secret_vars_store.populate_dir(
machine.name, directory, phases=["activation", "users", "services"]
machine.name,
directory,
phases=["activation", "users", "services"],
)

View File

@@ -6,7 +6,8 @@ from clan_lib.errors import ClanError
def test_upload_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.chdir(tmp_path)

View File

@@ -13,6 +13,6 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
)
register_inspect_parser(
subparser.add_parser("inspect", help="inspect the vm configuration")
subparser.add_parser("inspect", help="inspect the vm configuration"),
)
register_run_parser(subparser.add_parser("run", help="run a VM from a machine"))

View File

@@ -204,7 +204,7 @@ def qemu_command(
"chardev=char0,mode=readline",
"-device",
"virtconsole,chardev=char0,nr=0",
]
],
)
else:
command.extend(
@@ -217,7 +217,7 @@ def qemu_command(
"virtconsole,chardev=char0,nr=0",
"-monitor",
"none",
]
],
)
vsock_cid = None

View File

@@ -43,14 +43,16 @@ def facts_to_nixos_config(facts: dict[str, dict[str, bytes]]) -> dict:
nixos_config["clan"]["core"]["secrets"][service]["facts"] = {}
for fact, value in service_facts.items():
nixos_config["clan"]["core"]["secrets"][service]["facts"][fact] = {
"value": value.decode()
"value": value.decode(),
}
return nixos_config
# TODO move this to the Machines class
def build_vm(
machine: Machine, tmpdir: Path, nix_options: list[str] | None = None
machine: Machine,
tmpdir: Path,
nix_options: list[str] | None = None,
) -> dict[str, str]:
# TODO pass prompt here for the GTK gui
if nix_options is None:
@@ -60,7 +62,7 @@ def build_vm(
output = Path(
machine.select(
"config.system.clan.vm.create",
)
),
)
if tmp_store := nix_test_store():
output = tmp_store.joinpath(*output.parts[1:])
@@ -129,7 +131,11 @@ def start_vm(
machine.debug(f"Starting VM with command: {cmd}")
with subprocess.Popen(
cmd, env=env, stdout=stdout, stderr=stderr, stdin=stdin
cmd,
env=env,
stdout=stdout,
stderr=stderr,
stdin=stdin,
) as process:
try:
yield process
@@ -222,7 +228,7 @@ def spawn_vm(
if cachedir is None:
cache_tmp = stack.enter_context(
TemporaryDirectory(prefix="vm-cache-", dir=cache)
TemporaryDirectory(prefix="vm-cache-", dir=cache),
)
cachedir = Path(cache_tmp)
@@ -403,7 +409,9 @@ def run_command(
def register_run_parser(parser: argparse.ArgumentParser) -> None:
machine_action = parser.add_argument(
"machine", type=str, help="machine in the flake to run"
"machine",
type=str,
help="machine in the flake to run",
)
add_dynamic_completer(machine_action, complete_machines)
# option: --publish 2222:22

View File

@@ -33,13 +33,13 @@ ResponseDataType = TypeVar("ResponseDataType")
class ProcessMessage(TypedDict):
"""
Represents a message to be sent to the UI.
"""Represents a message to be sent to the UI.
Attributes:
- topic: The topic of the message, used to identify the type of message.
- data: The data to be sent with the message.
- origin: The API operation that this message is related to, if applicable.
"""
topic: str
@@ -173,7 +173,7 @@ API.register(get_system_file)
message=e.msg,
description=e.description,
location=[fn.__name__, e.location],
)
),
],
)
except Exception as e:
@@ -186,7 +186,7 @@ API.register(get_system_file)
message=str(e),
description="An unexpected error occurred",
location=[fn.__name__],
)
),
],
)
@@ -292,7 +292,8 @@ API.register(get_system_file)
def import_all_modules_from_package(pkg: ModuleType) -> None:
for _loader, module_name, _is_pkg in pkgutil.walk_packages(
pkg.__path__, prefix=f"{pkg.__name__}."
pkg.__path__,
prefix=f"{pkg.__name__}.",
):
base_name = module_name.split(".")[-1]
@@ -308,8 +309,7 @@ def import_all_modules_from_package(pkg: ModuleType) -> None:
def load_in_all_api_functions() -> None:
"""
For the global API object, to have all functions available.
"""For the global API object, to have all functions available.
We have to make sure python loads every wrapped function at least once.
This is done by importing all modules from the clan_lib and clan_cli packages.
"""

View File

@@ -32,8 +32,7 @@ class FileRequest:
@API.register_abstract
def get_system_file(file_request: FileRequest) -> list[str] | None:
"""
Api method to open a file dialog window.
"""Api method to open a file dialog window.
Implementations is specific to the platform and
returns the name of the selected file or None if no file was selected.
@@ -44,8 +43,7 @@ def get_system_file(file_request: FileRequest) -> list[str] | None:
@API.register_abstract
def get_clan_folder() -> Flake:
"""
Api method to open the clan folder.
"""Api method to open the clan folder.
Implementations is specific to the platform and returns the path to the clan folder.
"""
@@ -85,13 +83,12 @@ def blk_from_dict(data: dict) -> BlkInfo:
@API.register
def list_system_storage_devices() -> Blockdevices:
"""
List local block devices by running `lsblk`.
"""List local block devices by running `lsblk`.
Returns:
A list of detected block devices with metadata like size, path, type, etc.
"""
"""
cmd = nix_shell(
["util-linux"],
[
@@ -107,14 +104,13 @@ def list_system_storage_devices() -> Blockdevices:
blk_info: dict[str, Any] = json.loads(res)
return Blockdevices(
blockdevices=[blk_from_dict(device) for device in blk_info["blockdevices"]]
blockdevices=[blk_from_dict(device) for device in blk_info["blockdevices"]],
)
@API.register
def get_clan_directory_relative(flake: Flake) -> str:
"""
Get the clan directory path relative to the flake root
"""Get the clan directory path relative to the flake root
from the clan.directory configuration setting.
Args:
@@ -125,6 +121,7 @@ def get_clan_directory_relative(flake: Flake) -> str:
Raises:
ClanError: If the flake evaluation fails or directories cannot be found
"""
from clan_lib.dirs import get_clan_directories
@@ -133,12 +130,13 @@ def get_clan_directory_relative(flake: Flake) -> str:
def get_clan_dir(flake: Flake) -> Path:
"""
Get the effective clan directory, respecting the clan.directory configuration.
"""Get the effective clan directory, respecting the clan.directory configuration.
Args:
flake: The clan flake
Returns:
Path to the effective clan directory
"""
relative_clan_dir = get_clan_directory_relative(flake)
return flake.path / relative_clan_dir if relative_clan_dir else flake.path

View File

@@ -29,7 +29,7 @@ def test_get_relative_clan_directory_custom(
{
directory = ./direct-config;
}
"""
""",
)
test_subdir = Path(flake.path) / "direct-config"
@@ -68,7 +68,7 @@ def test_get_clan_dir_custom(
{
directory = ./direct-config;
}
"""
""",
)
test_subdir = Path(flake.path) / "direct-config"

View File

@@ -90,8 +90,10 @@ def parse_avahi_output(output: str) -> DNSInfo:
@API.register
def list_system_services_mdns() -> DNSInfo:
"""List mDNS/DNS-SD services on the local network.
Returns:
DNSInfo: A dictionary containing discovered mDNS/DNS-SD services.
"""
cmd = nix_shell(
["avahi"],

View File

@@ -1,5 +1,4 @@
"""
This module provides utility functions for serialization and deserialization of data classes.
"""This module provides utility functions for serialization and deserialization of data classes.
Functions:
- sanitize_string(s: str) -> str: Ensures a string is properly escaped for json serializing.
@@ -56,9 +55,7 @@ def sanitize_string(s: str) -> str:
def is_enum(obj: Any) -> bool:
"""
Safely checks if the object or one of its attributes is an Enum.
"""
"""Safely checks if the object or one of its attributes is an Enum."""
# Check if the object itself is an Enum
if isinstance(obj, Enum):
return True
@@ -69,9 +66,7 @@ def is_enum(obj: Any) -> bool:
def get_enum_value(obj: Any) -> Any:
"""
Safely checks if the object or one of its attributes is an Enum.
"""
"""Safely checks if the object or one of its attributes is an Enum."""
# Check if the object itself is an Enum
value = getattr(obj, "value", None)
if value is None and obj.enum:
@@ -85,8 +80,7 @@ def get_enum_value(obj: Any) -> Any:
def dataclass_to_dict(obj: Any, *, use_alias: bool = True) -> Any:
"""
Converts objects to dictionaries.
"""Converts objects to dictionaries.
This function is round trip safe.
Meaning that if you convert the object to a dict and then back to a dataclass using 'from_dict'
@@ -103,8 +97,7 @@ def dataclass_to_dict(obj: Any, *, use_alias: bool = True) -> Any:
"""
def _to_dict(obj: Any) -> Any:
"""
Utility function to convert dataclasses to dictionaries
"""Utility function to convert dataclasses to dictionaries
It converts all nested dataclasses, lists, tuples, and dictionaries to dictionaries
It does NOT convert member functions.
@@ -115,7 +108,9 @@ def dataclass_to_dict(obj: Any, *, use_alias: bool = True) -> Any:
return {
# Use either the original name or name
sanitize_string(
field.metadata.get("alias", field.name) if use_alias else field.name
field.metadata.get("alias", field.name)
if use_alias
else field.name,
): _to_dict(getattr(obj, field.name))
for field in fields(obj)
if not field.name.startswith("_")
@@ -173,13 +168,11 @@ def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
def unwrap_none_type(type_hint: type | UnionType) -> type:
"""
Takes a type union and returns the first non-None type.
"""Takes a type union and returns the first non-None type.
None | str
=>
str
"""
if is_union_type(type_hint):
# Return the first non-None type
return next(t for t in get_args(type_hint) if t is not type(None))
@@ -191,10 +184,11 @@ JsonValue = str | float | dict[str, Any] | list[Any] | None
def construct_value(
t: type | UnionType, field_value: JsonValue, loc: list[str] | None = None
t: type | UnionType,
field_value: JsonValue,
loc: list[str] | None = None,
) -> Any:
"""
Construct a field value from a type hint and a field value.
"""Construct a field value from a type hint and a field value.
The following types are supported and matched in this order:
@@ -328,10 +322,11 @@ def construct_value(
def construct_dataclass[T: Any](
t: type[T], data: dict[str, Any], path: list[str] | None = None
t: type[T],
data: dict[str, Any],
path: list[str] | None = None,
) -> T:
"""
type t MUST be a dataclass
"""Type t MUST be a dataclass
Dynamically instantiate a data class from a dictionary, handling nested data classes.
Constructs the field values from the data dictionary using 'construct_value'
@@ -383,10 +378,11 @@ def construct_dataclass[T: Any](
def from_dict(
t: type | UnionType, data: dict[str, Any] | Any, path: list[str] | None = None
t: type | UnionType,
data: dict[str, Any] | Any,
path: list[str] | None = None,
) -> Any:
"""
Dynamically instantiate a data class from a dictionary, handling nested data classes.
"""Dynamically instantiate a data class from a dictionary, handling nested data classes.
This function is round trip safe in conjunction with 'dataclass_to_dict'
"""

View File

@@ -102,7 +102,9 @@ def test_nested_nullable() -> None:
mode="format",
disks={"main": "/dev/sda"},
system_config=SystemConfig(
language="en_US.UTF-8", keymap="en", ssh_keys_path=None
language="en_US.UTF-8",
keymap="en",
ssh_keys_path=None,
),
dry_run=False,
write_efi_boot_entries=False,
@@ -182,9 +184,7 @@ def test_alias_field() -> None:
def test_alias_field_from_orig_name() -> None:
"""
Field declares an alias. But the data is provided with the field name.
"""
"""Field declares an alias. But the data is provided with the field name."""
@dataclass
class Person:
@@ -197,10 +197,7 @@ def test_alias_field_from_orig_name() -> None:
def test_none_or_string() -> None:
"""
Field declares an alias. But the data is provided with the field name.
"""
"""Field declares an alias. But the data is provided with the field name."""
data = None
@dataclass
@@ -218,8 +215,7 @@ def test_none_or_string() -> None:
def test_union_with_none_edge_cases() -> None:
"""
Test various union types with None to ensure issubclass() error is avoided.
"""Test various union types with None to ensure issubclass() error is avoided.
This specifically tests the fix for the TypeError in is_type_in_union.
"""
# Test basic types with None

View File

@@ -7,7 +7,6 @@ from clan_lib.api import (
)
#
def test_sanitize_string() -> None:
# Simple strings
assert sanitize_string("Hello World") == "Hello World"

View File

@@ -44,7 +44,7 @@ def run_task_blocking(somearg: str) -> str:
log.debug("Task was cancelled")
return "Task was cancelled"
log.debug(
f"Processing {i} for {somearg}. ctx.should_cancel={ctx.should_cancel()}"
f"Processing {i} for {somearg}. ctx.should_cancel={ctx.should_cancel()}",
)
time.sleep(1)
return f"Task completed with argument: {somearg}"

View File

@@ -29,9 +29,7 @@ class JSchemaTypeError(Exception):
# Inspect the fields of the parameterized type
def inspect_dataclass_fields(t: type) -> dict[TypeVar, type]:
"""
Returns a map of type variables to actual types for a parameterized type.
"""
"""Returns a map of type variables to actual types for a parameterized type."""
origin = get_origin(t)
type_args = get_args(t)
if origin is None:
@@ -45,13 +43,12 @@ def inspect_dataclass_fields(t: type) -> dict[TypeVar, type]:
def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[str, Any]:
"""
Add metadata from typing.annotations to the json Schema.
"""Add metadata from typing.annotations to the json Schema.
The annotations can be a dict, a tuple, or a string and is directly applied to the schema as shown below.
No further validation is done, the caller is responsible for following json-schema.
Examples
--------
```python
# String annotation
Annotated[int, "This is an int"] -> {"type": "integer", "description": "This is an int"}
@@ -62,6 +59,7 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
# Tuple annotation
Annotated[int, ("minimum", 0)] -> {"type": "integer", "minimum": 0}
```
"""
for annotation in annotations:
if isinstance(annotation, dict):
@@ -96,8 +94,7 @@ def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
def is_total(typed_dict_class: type) -> bool:
"""
Check if a TypedDict has total=true
"""Check if a TypedDict has total=true
https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-total-false
"""
return getattr(typed_dict_class, "__total__", True) # Default to True if not set
@@ -177,7 +174,9 @@ def type_to_dict(
explicit_required.add(field_name)
dict_properties[field_name] = type_to_dict(
field_type, f"{scope} {t.__name__}.{field_name}", type_map
field_type,
f"{scope} {t.__name__}.{field_name}",
type_map,
)
optional = set(dict_fields) - explicit_optional
@@ -195,7 +194,7 @@ def type_to_dict(
for arg in get_args(t):
try:
supported.append(
type_to_dict(arg, scope, type_map, narrow_unsupported_union_types)
type_to_dict(arg, scope, type_map, narrow_unsupported_union_types),
)
except JSchemaTypeError:
if narrow_unsupported_union_types:

View File

@@ -85,7 +85,7 @@ def test_simple_union_types() -> None:
"oneOf": [
{"type": "integer"},
{"type": "string"},
]
],
}
assert type_to_dict(int | str | float) == {
@@ -93,7 +93,7 @@ def test_simple_union_types() -> None:
{"type": "integer"},
{"type": "string"},
{"type": "number"},
]
],
}
assert type_to_dict(int | str | None) == {
@@ -101,7 +101,7 @@ def test_simple_union_types() -> None:
{"type": "integer"},
{"type": "string"},
{"type": "null"},
]
],
}
@@ -133,7 +133,7 @@ def test_complex_union_types() -> None:
"required": ["bar"],
},
{"type": "null"},
]
],
}
@@ -187,7 +187,7 @@ def test_dataclasses() -> None:
},
"additionalProperties": False,
"required": [
"name"
"name",
], # value is optional because it has a default value of None
}

View File

@@ -44,17 +44,14 @@ class AsyncResult[R]:
@property
def error(self) -> Exception | None:
"""
Returns an error if the callable raised an exception.
"""
"""Returns an error if the callable raised an exception."""
if isinstance(self._result, Exception):
return self._result
return None
@property
def result(self) -> R:
"""
Unwraps and returns the result if no exception occurred.
"""Unwraps and returns the result if no exception occurred.
Raises the exception otherwise.
"""
if isinstance(self._result, Exception):
@@ -64,9 +61,7 @@ class AsyncResult[R]:
@dataclass
class AsyncContext:
"""
This class stores thread-local data.
"""
"""This class stores thread-local data."""
prefix: str | None = None # prefix for logging
stdout: IO[bytes] | None = None # stdout of subprocesses
@@ -79,9 +74,7 @@ class AsyncContext:
@dataclass
class AsyncOpts:
"""
Options for the async_run function.
"""
"""Options for the async_run function."""
tid: str | None = None
check: bool = True
@@ -92,39 +85,29 @@ ASYNC_CTX_THREAD_LOCAL = threading.local()
def set_current_thread_opkey(op_key: str) -> None:
"""
Set the current thread's operation key.
"""
"""Set the current thread's operation key."""
ctx = get_async_ctx()
ctx.op_key = op_key
def get_current_thread_opkey() -> str | None:
"""
Get the current thread's operation key.
"""
"""Get the current thread's operation key."""
ctx = get_async_ctx()
return ctx.op_key
def is_async_cancelled() -> bool:
"""
Check if the current task has been cancelled.
"""
"""Check if the current task has been cancelled."""
return get_async_ctx().should_cancel()
def set_should_cancel(should_cancel: Callable[[], bool]) -> None:
"""
Set the cancellation function for the current task.
"""
"""Set the cancellation function for the current task."""
get_async_ctx().should_cancel = should_cancel
def get_async_ctx() -> AsyncContext:
"""
Retrieve the current AsyncContext, creating a new one if none exists.
"""
"""Retrieve the current AsyncContext, creating a new one if none exists."""
global ASYNC_CTX_THREAD_LOCAL
if not hasattr(ASYNC_CTX_THREAD_LOCAL, "async_ctx"):
@@ -155,9 +138,7 @@ class AsyncThread[**P, R](threading.Thread):
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
A threaded wrapper for running a function asynchronously.
"""
"""A threaded wrapper for running a function asynchronously."""
super().__init__()
self.function = function
self.args = args
@@ -169,9 +150,7 @@ class AsyncThread[**P, R](threading.Thread):
self.stop_event = stop_event # Event to signal cancellation
def run(self) -> None:
"""
Run the function in a separate thread.
"""
"""Run the function in a separate thread."""
try:
set_should_cancel(lambda: self.stop_event.is_set())
# Arguments for ParamSpec "P@AsyncThread" are missing
@@ -191,9 +170,7 @@ class AsyncFuture[R]:
_runtime: "AsyncRuntime"
def wait(self) -> AsyncResult[R]:
"""
Wait for the task to finish.
"""
"""Wait for the task to finish."""
if self._tid not in self._runtime.tasks:
msg = f"No task with the name '{self._tid}' exists."
raise ClanError(msg)
@@ -207,9 +184,7 @@ class AsyncFuture[R]:
return result
def get_result(self) -> AsyncResult[R] | None:
"""
Retrieve the result of a finished task and remove it from the task list.
"""
"""Retrieve the result of a finished task and remove it from the task list."""
if self._tid not in self._runtime.tasks:
msg = f"No task with the name '{self._tid}' exists."
raise ClanError(msg)
@@ -251,8 +226,7 @@ class AsyncRuntime:
*args: P.args,
**kwargs: P.kwargs,
) -> AsyncFuture[R]:
"""
Run the given function asynchronously in a thread with a specific name and arguments.
"""Run the given function asynchronously in a thread with a specific name and arguments.
The function's static typing is preserved.
"""
if opts is None:
@@ -268,7 +242,12 @@ class AsyncRuntime:
stop_event = threading.Event()
# Create and start the new AsyncThread
thread = AsyncThread(
opts, self.condition, stop_event, function, *args, **kwargs
opts,
self.condition,
stop_event,
function,
*args,
**kwargs,
)
self.tasks[opts.tid] = thread
thread.start()
@@ -282,17 +261,14 @@ class AsyncRuntime:
*args: P.args,
**kwargs: P.kwargs,
) -> AsyncFutureRef[R, Q]:
"""
The same as async_run, but with an additional reference to an object.
"""The same as async_run, but with an additional reference to an object.
This is useful to keep track of the origin of the task.
"""
future = self.async_run(opts, function, *args, **kwargs)
return AsyncFutureRef(_tid=future._tid, _runtime=self, ref=ref) # noqa: SLF001
def join_all(self) -> None:
"""
Wait for all tasks to finish
"""
"""Wait for all tasks to finish"""
with self.condition:
while any(
not task.finished for task in self.tasks.values()
@@ -300,9 +276,7 @@ class AsyncRuntime:
self.condition.wait() # Wait until a thread signals completion
def check_all(self) -> None:
"""
Check if there where any errors
"""
"""Check if there where any errors"""
err_count = 0
for name, task in self.tasks.items():
@@ -328,9 +302,7 @@ class AsyncRuntime:
raise ClanError(msg)
def __enter__(self) -> "AsyncRuntime":
"""
Enter the runtime context related to this object.
"""
"""Enter the runtime context related to this object."""
return self
def __exit__(
@@ -339,8 +311,7 @@ class AsyncRuntime:
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
"""
Exit the runtime context related to this object.
"""Exit the runtime context related to this object.
Sets async_ctx.cancel to True to signal cancellation.
"""
for name, task in self.tasks.items():

View File

@@ -5,7 +5,11 @@ from clan_lib.ssh.remote import Remote
def restore_service(
machine: Machine, host: Remote, name: str, provider: str, service: str
machine: Machine,
host: Remote,
name: str,
provider: str,
service: str,
) -> None:
backup_metadata = machine.select("config.clan.core.backups")
backup_folders = machine.select("config.clan.core.state")
@@ -73,5 +77,5 @@ def restore_backup(
errors.append(f"{service}: {e}")
if errors:
raise ClanError(
"Restore failed for the following services:\n" + "\n".join(errors)
"Restore failed for the following services:\n" + "\n".join(errors),
)

Some files were not shown because too many files have changed in this diff Show More