ruff: apply automatic fixes

This commit is contained in:
Jörg Thalheim
2025-08-20 13:52:45 +02:00
parent 798d445f3e
commit ea2d6aab65
217 changed files with 2283 additions and 1739 deletions

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """IPv6 address allocator for WireGuard networks.
IPv6 address allocator for WireGuard networks.
Network layout: Network layout:
- Base network: /40 ULA prefix (fd00::/8 + 32 bits from hash) - Base network: /40 ULA prefix (fd00::/8 + 32 bits from hash)
@@ -20,8 +19,7 @@ def hash_string(s: str) -> str:
def generate_ula_prefix(instance_name: str) -> ipaddress.IPv6Network: def generate_ula_prefix(instance_name: str) -> ipaddress.IPv6Network:
""" """Generate a /40 ULA prefix from instance name.
Generate a /40 ULA prefix from instance name.
Format: fd{32-bit hash}/40 Format: fd{32-bit hash}/40
This gives us fd00:0000:0000::/40 through fdff:ffff:ff00::/40 This gives us fd00:0000:0000::/40 through fdff:ffff:ff00::/40
@@ -46,10 +44,10 @@ def generate_ula_prefix(instance_name: str) -> ipaddress.IPv6Network:
def generate_controller_subnet( def generate_controller_subnet(
base_network: ipaddress.IPv6Network, controller_name: str base_network: ipaddress.IPv6Network,
controller_name: str,
) -> ipaddress.IPv6Network: ) -> ipaddress.IPv6Network:
""" """Generate a /56 subnet for a controller from the base /40 network.
Generate a /56 subnet for a controller from the base /40 network.
We have 16 bits (40 to 56) to allocate controller subnets. We have 16 bits (40 to 56) to allocate controller subnets.
This allows for 65,536 possible controller subnets. This allows for 65,536 possible controller subnets.
@@ -68,8 +66,7 @@ def generate_controller_subnet(
def generate_peer_suffix(peer_name: str) -> str: def generate_peer_suffix(peer_name: str) -> str:
""" """Generate a unique 64-bit host suffix for a peer.
Generate a unique 64-bit host suffix for a peer.
This suffix will be used in all controller subnets to create unique addresses. This suffix will be used in all controller subnets to create unique addresses.
Format: :xxxx:xxxx:xxxx:xxxx (64 bits) Format: :xxxx:xxxx:xxxx:xxxx (64 bits)
@@ -86,7 +83,7 @@ def generate_peer_suffix(peer_name: str) -> str:
def main() -> None: def main() -> None:
if len(sys.argv) < 4: if len(sys.argv) < 4:
print( print(
"Usage: ipv6_allocator.py <output_dir> <instance_name> <controller|peer> <machine_name>" "Usage: ipv6_allocator.py <output_dir> <instance_name> <controller|peer> <machine_name>",
) )
sys.exit(1) sys.exit(1)

View File

@@ -66,8 +66,7 @@ def render_option_header(name: str) -> str:
def join_lines_with_indentation(lines: list[str], indent: int = 4) -> str: def join_lines_with_indentation(lines: list[str], indent: int = 4) -> str:
""" """Joins multiple lines with a specified number of whitespace characters as indentation.
Joins multiple lines with a specified number of whitespace characters as indentation.
Args: Args:
lines (list of str): The lines of text to join. lines (list of str): The lines of text to join.
@@ -75,6 +74,7 @@ def join_lines_with_indentation(lines: list[str], indent: int = 4) -> str:
Returns: Returns:
str: The indented and concatenated string. str: The indented and concatenated string.
""" """
# Create the indentation string (e.g., four spaces) # Create the indentation string (e.g., four spaces)
indent_str = " " * indent indent_str = " " * indent
@@ -161,7 +161,10 @@ def render_option(
def print_options( def print_options(
options_file: str, head: str, no_options: str, replace_prefix: str | None = None options_file: str,
head: str,
no_options: str,
replace_prefix: str | None = None,
) -> str: ) -> str:
res = "" res = ""
with (Path(options_file) / "share/doc/nixos/options.json").open() as f: with (Path(options_file) / "share/doc/nixos/options.json").open() as f:
@@ -235,7 +238,7 @@ def produce_clan_core_docs() -> None:
for submodule_name, split_options in split.items(): for submodule_name, split_options in split.items():
outfile = f"{module_name}/{submodule_name}.md" outfile = f"{module_name}/{submodule_name}.md"
print( print(
f"[clan_core.{submodule_name}] Rendering option of: {submodule_name}... {outfile}" f"[clan_core.{submodule_name}] Rendering option of: {submodule_name}... {outfile}",
) )
init_level = 1 init_level = 1
root = options_to_tree(split_options, debug=True) root = options_to_tree(split_options, debug=True)
@@ -271,7 +274,8 @@ def produce_clan_core_docs() -> None:
def render_categories( def render_categories(
categories: list[str], categories_info: dict[str, CategoryInfo] categories: list[str],
categories_info: dict[str, CategoryInfo],
) -> str: ) -> str:
res = """<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px;">""" res = """<div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px;">"""
for cat in categories: for cat in categories:
@@ -338,7 +342,8 @@ Learn how to use `clanServices` in practice in the [Using clanServices guide](..
# output += "## Categories\n\n" # output += "## Categories\n\n"
output += render_categories( output += render_categories(
module_info["manifest"]["categories"], ModuleManifest.categories_info() module_info["manifest"]["categories"],
ModuleManifest.categories_info(),
) )
output += f"{module_info['manifest']['readme']}\n" output += f"{module_info['manifest']['readme']}\n"
@@ -368,8 +373,7 @@ Learn how to use `clanServices` in practice in the [Using clanServices guide](..
def split_options_by_root(options: dict[str, Any]) -> dict[str, dict[str, Any]]: def split_options_by_root(options: dict[str, Any]) -> dict[str, dict[str, Any]]:
""" """Split the flat dictionary of options into a dict of which each entry will construct complete option trees.
Split the flat dictionary of options into a dict of which each entry will construct complete option trees.
{ {
"a": { Data } "a": { Data }
"a.b": { Data } "a.b": { Data }
@@ -453,9 +457,7 @@ def option_short_name(option_name: str) -> str:
def options_to_tree(options: dict[str, Any], debug: bool = False) -> Option: def options_to_tree(options: dict[str, Any], debug: bool = False) -> Option:
""" """Convert the options dictionary to a tree structure."""
Convert the options dictionary to a tree structure.
"""
# Helper function to create nested structure # Helper function to create nested structure
def add_to_tree(path_parts: list[str], info: Any, current_node: Option) -> None: def add_to_tree(path_parts: list[str], info: Any, current_node: Option) -> None:
@@ -507,22 +509,24 @@ def options_to_tree(options: dict[str, Any], debug: bool = False) -> Option:
def options_docs_from_tree( def options_docs_from_tree(
root: Option, init_level: int = 1, prefix: list[str] | None = None root: Option,
init_level: int = 1,
prefix: list[str] | None = None,
) -> str: ) -> str:
""" """Eender the options from the tree structure.
eender the options from the tree structure.
Args: Args:
root (Option): The root option node. root (Option): The root option node.
init_level (int): The initial level of indentation. init_level (int): The initial level of indentation.
prefix (list str): Will be printed as common prefix of all attribute names. prefix (list str): Will be printed as common prefix of all attribute names.
""" """
def render_tree(option: Option, level: int = init_level) -> str: def render_tree(option: Option, level: int = init_level) -> str:
output = "" output = ""
should_render = not option.name.startswith("<") and not option.name.startswith( should_render = not option.name.startswith("<") and not option.name.startswith(
"_" "_",
) )
if should_render: if should_render:
# short_name = option_short_name(option.name) # short_name = option_short_name(option.name)
@@ -547,7 +551,7 @@ def options_docs_from_tree(
return md return md
if __name__ == "__main__": # if __name__ == "__main__":
produce_clan_core_docs() produce_clan_core_docs()
produce_clan_service_author_docs() produce_clan_service_author_docs()

View File

@@ -32,11 +32,15 @@ def init_test_environment() -> None:
# Set up network bridge # Set up network bridge
subprocess.run( subprocess.run(
["ip", "link", "add", "br0", "type", "bridge"], check=True, text=True ["ip", "link", "add", "br0", "type", "bridge"],
check=True,
text=True,
) )
subprocess.run(["ip", "link", "set", "br0", "up"], check=True, text=True) subprocess.run(["ip", "link", "set", "br0", "up"], check=True, text=True)
subprocess.run( subprocess.run(
["ip", "addr", "add", "192.168.1.254/24", "dev", "br0"], check=True, text=True ["ip", "addr", "add", "192.168.1.254/24", "dev", "br0"],
check=True,
text=True,
) )
# Set up minimal passwd file for unprivileged operations # Set up minimal passwd file for unprivileged operations
@@ -111,8 +115,7 @@ def mount(
mountflags: int = 0, mountflags: int = 0,
data: str | None = None, data: str | None = None,
) -> None: ) -> None:
""" """A Python wrapper for the mount system call.
A Python wrapper for the mount system call.
:param source: The source of the file system (e.g., device name, remote filesystem). :param source: The source of the file system (e.g., device name, remote filesystem).
:param target: The mount point (an existing directory). :param target: The mount point (an existing directory).
@@ -129,7 +132,11 @@ def mount(
# Call the mount system call # Call the mount system call
result = libc.mount( result = libc.mount(
source_c, target_c, fstype_c, ctypes.c_ulong(mountflags), data_c source_c,
target_c,
fstype_c,
ctypes.c_ulong(mountflags),
data_c,
) )
if result != 0: if result != 0:
@@ -145,7 +152,7 @@ def prepare_machine_root(machinename: str, root: Path) -> None:
root.mkdir(parents=True, exist_ok=True) root.mkdir(parents=True, exist_ok=True)
root.joinpath("etc").mkdir(parents=True, exist_ok=True) root.joinpath("etc").mkdir(parents=True, exist_ok=True)
root.joinpath(".env").write_text( root.joinpath(".env").write_text(
"\n".join(f"{k}={v}" for k, v in os.environ.items()) "\n".join(f"{k}={v}" for k, v in os.environ.items()),
) )
@@ -157,7 +164,6 @@ def retry(fn: Callable, timeout: int = 900) -> None:
"""Call the given function repeatedly, with 1 second intervals, """Call the given function repeatedly, with 1 second intervals,
until it returns True or a timeout is reached. until it returns True or a timeout is reached.
""" """
for _ in range(timeout): for _ in range(timeout):
if fn(False): if fn(False):
return return
@@ -284,8 +290,7 @@ class Machine:
check_output: bool = True, check_output: bool = True,
timeout: int | None = 900, timeout: int | None = 900,
) -> subprocess.CompletedProcess: ) -> subprocess.CompletedProcess:
""" """Execute a shell command, returning a list `(status, stdout)`.
Execute a shell command, returning a list `(status, stdout)`.
Commands are run with `set -euo pipefail` set: Commands are run with `set -euo pipefail` set:
@@ -316,7 +321,6 @@ class Machine:
`timeout` parameter, e.g., `execute(cmd, timeout=10)` or `timeout` parameter, e.g., `execute(cmd, timeout=10)` or
`execute(cmd, timeout=None)`. The default is 900 seconds. `execute(cmd, timeout=None)`. The default is 900 seconds.
""" """
# Always run command with shell opts # Always run command with shell opts
command = f"set -eo pipefail; source /etc/profile; set -xu; {command}" command = f"set -eo pipefail; source /etc/profile; set -xu; {command}"
@@ -330,7 +334,9 @@ class Machine:
return proc return proc
def nested( def nested(
self, msg: str, attrs: dict[str, str] | None = None self,
msg: str,
attrs: dict[str, str] | None = None,
) -> _GeneratorContextManager: ) -> _GeneratorContextManager:
if attrs is None: if attrs is None:
attrs = {} attrs = {}
@@ -339,8 +345,7 @@ class Machine:
return self.logger.nested(msg, my_attrs) return self.logger.nested(msg, my_attrs)
def systemctl(self, q: str) -> subprocess.CompletedProcess: def systemctl(self, q: str) -> subprocess.CompletedProcess:
""" """Runs `systemctl` commands with optional support for
Runs `systemctl` commands with optional support for
`systemctl --user` `systemctl --user`
```py ```py
@@ -355,8 +360,7 @@ class Machine:
return self.execute(f"systemctl {q}") return self.execute(f"systemctl {q}")
def wait_until_succeeds(self, command: str, timeout: int = 900) -> str: def wait_until_succeeds(self, command: str, timeout: int = 900) -> str:
""" """Repeat a shell command with 1-second intervals until it succeeds.
Repeat a shell command with 1-second intervals until it succeeds.
Has a default timeout of 900 seconds which can be modified, e.g. Has a default timeout of 900 seconds which can be modified, e.g.
`wait_until_succeeds(cmd, timeout=10)`. See `execute` for details on `wait_until_succeeds(cmd, timeout=10)`. See `execute` for details on
command execution. command execution.
@@ -374,18 +378,17 @@ class Machine:
return output return output
def wait_for_open_port( def wait_for_open_port(
self, port: int, addr: str = "localhost", timeout: int = 900 self,
port: int,
addr: str = "localhost",
timeout: int = 900,
) -> None: ) -> None:
""" """Wait for a port to be open on the given address."""
Wait for a port to be open on the given address.
"""
command = f"nc -z {shlex.quote(addr)} {port}" command = f"nc -z {shlex.quote(addr)} {port}"
self.wait_until_succeeds(command, timeout=timeout) self.wait_until_succeeds(command, timeout=timeout)
def wait_for_file(self, filename: str, timeout: int = 30) -> None: def wait_for_file(self, filename: str, timeout: int = 30) -> None:
""" """Waits until the file exists in the machine's file system."""
Waits until the file exists in the machine's file system.
"""
def check_file(_last_try: bool) -> bool: def check_file(_last_try: bool) -> bool:
result = self.execute(f"test -e {filename}") result = self.execute(f"test -e {filename}")
@@ -395,8 +398,7 @@ class Machine:
retry(check_file, timeout) retry(check_file, timeout)
def wait_for_unit(self, unit: str, timeout: int = 900) -> None: def wait_for_unit(self, unit: str, timeout: int = 900) -> None:
""" """Wait for a systemd unit to get into "active" state.
Wait for a systemd unit to get into "active" state.
Throws exceptions on "failed" and "inactive" states as well as after Throws exceptions on "failed" and "inactive" states as well as after
timing out. timing out.
""" """
@@ -441,9 +443,7 @@ class Machine:
return res.stdout return res.stdout
def shutdown(self) -> None: def shutdown(self) -> None:
""" """Shut down the machine, waiting for the VM to exit."""
Shut down the machine, waiting for the VM to exit.
"""
if self.process: if self.process:
self.process.terminate() self.process.terminate()
self.process.wait() self.process.wait()
@@ -557,7 +557,7 @@ class Driver:
rootdir=tempdir_path / container.name, rootdir=tempdir_path / container.name,
out_dir=self.out_dir, out_dir=self.out_dir,
logger=self.logger, logger=self.logger,
) ),
) )
def start_all(self) -> None: def start_all(self) -> None:
@@ -581,7 +581,7 @@ class Driver:
) )
print( print(
f"To attach to container {machine.name} run on the same machine that runs the test:" f"To attach to container {machine.name} run on the same machine that runs the test:",
) )
print( print(
" ".join( " ".join(
@@ -603,8 +603,8 @@ class Driver:
"-c", "-c",
"bash", "bash",
Style.RESET_ALL, Style.RESET_ALL,
] ],
) ),
) )
def test_symbols(self) -> dict[str, Any]: def test_symbols(self) -> dict[str, Any]:
@@ -623,7 +623,7 @@ class Driver:
"additionally exposed symbols:\n " "additionally exposed symbols:\n "
+ ", ".join(m.name for m in self.machines) + ", ".join(m.name for m in self.machines)
+ ",\n " + ",\n "
+ ", ".join(list(general_symbols.keys())) + ", ".join(list(general_symbols.keys())),
) )
return {**general_symbols, **machine_symbols} return {**general_symbols, **machine_symbols}

View File

@@ -25,14 +25,18 @@ class AbstractLogger(ABC):
@abstractmethod @abstractmethod
@contextmanager @contextmanager
def subtest( def subtest(
self, name: str, attributes: dict[str, str] | None = None self,
name: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
pass pass
@abstractmethod @abstractmethod
@contextmanager @contextmanager
def nested( def nested(
self, message: str, attributes: dict[str, str] | None = None self,
message: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
pass pass
@@ -66,7 +70,7 @@ class JunitXMLLogger(AbstractLogger):
def __init__(self, outfile: Path) -> None: def __init__(self, outfile: Path) -> None:
self.tests: dict[str, JunitXMLLogger.TestCaseState] = { self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
"main": self.TestCaseState() "main": self.TestCaseState(),
} }
self.currentSubtest = "main" self.currentSubtest = "main"
self.outfile: Path = outfile self.outfile: Path = outfile
@@ -78,7 +82,9 @@ class JunitXMLLogger(AbstractLogger):
@contextmanager @contextmanager
def subtest( def subtest(
self, name: str, attributes: dict[str, str] | None = None self,
name: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
old_test = self.currentSubtest old_test = self.currentSubtest
self.tests.setdefault(name, self.TestCaseState()) self.tests.setdefault(name, self.TestCaseState())
@@ -90,7 +96,9 @@ class JunitXMLLogger(AbstractLogger):
@contextmanager @contextmanager
def nested( def nested(
self, message: str, attributes: dict[str, str] | None = None self,
message: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
self.log(message) self.log(message)
yield yield
@@ -144,7 +152,9 @@ class CompositeLogger(AbstractLogger):
@contextmanager @contextmanager
def subtest( def subtest(
self, name: str, attributes: dict[str, str] | None = None self,
name: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
with ExitStack() as stack: with ExitStack() as stack:
for logger in self.logger_list: for logger in self.logger_list:
@@ -153,7 +163,9 @@ class CompositeLogger(AbstractLogger):
@contextmanager @contextmanager
def nested( def nested(
self, message: str, attributes: dict[str, str] | None = None self,
message: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
with ExitStack() as stack: with ExitStack() as stack:
for logger in self.logger_list: for logger in self.logger_list:
@@ -200,19 +212,24 @@ class TerminalLogger(AbstractLogger):
@contextmanager @contextmanager
def subtest( def subtest(
self, name: str, attributes: dict[str, str] | None = None self,
name: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
with self.nested("subtest: " + name, attributes): with self.nested("subtest: " + name, attributes):
yield yield
@contextmanager @contextmanager
def nested( def nested(
self, message: str, attributes: dict[str, str] | None = None self,
message: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
self._eprint( self._eprint(
self.maybe_prefix( self.maybe_prefix(
Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL,
) attributes,
),
) )
tic = time.time() tic = time.time()
@@ -259,7 +276,9 @@ class XMLLogger(AbstractLogger):
return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C") return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
def maybe_prefix( def maybe_prefix(
self, message: str, attributes: dict[str, str] | None = None self,
message: str,
attributes: dict[str, str] | None = None,
) -> str: ) -> str:
if attributes and "machine" in attributes: if attributes and "machine" in attributes:
return f"{attributes['machine']}: {message}" return f"{attributes['machine']}: {message}"
@@ -309,14 +328,18 @@ class XMLLogger(AbstractLogger):
@contextmanager @contextmanager
def subtest( def subtest(
self, name: str, attributes: dict[str, str] | None = None self,
name: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
with self.nested("subtest: " + name, attributes): with self.nested("subtest: " + name, attributes):
yield yield
@contextmanager @contextmanager
def nested( def nested(
self, message: str, attributes: dict[str, str] | None = None self,
message: str,
attributes: dict[str, str] | None = None,
) -> Iterator[None]: ) -> Iterator[None]:
if attributes is None: if attributes is None:
attributes = {} attributes = {}

View File

@@ -195,7 +195,7 @@ def compute_zerotier_ip(network_id: str, identity: Identity) -> ipaddress.IPv6Ad
(node_id >> 16) & 0xFF, (node_id >> 16) & 0xFF,
(node_id >> 8) & 0xFF, (node_id >> 8) & 0xFF,
(node_id) & 0xFF, (node_id) & 0xFF,
] ],
) )
return ipaddress.IPv6Address(bytes(addr_parts)) return ipaddress.IPv6Address(bytes(addr_parts))
@@ -203,7 +203,10 @@ def compute_zerotier_ip(network_id: str, identity: Identity) -> ipaddress.IPv6Ad
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--mode", choices=["network", "identity"], required=True, type=str "--mode",
choices=["network", "identity"],
required=True,
type=str,
) )
parser.add_argument("--ip", type=Path, required=True) parser.add_argument("--ip", type=Path, required=True)
parser.add_argument("--identity-secret", type=Path, required=True) parser.add_argument("--identity-secret", type=Path, required=True)

View File

@@ -17,7 +17,7 @@ def main() -> None:
moon_json = json.loads(Path(moon_json_path).read_text()) moon_json = json.loads(Path(moon_json_path).read_text())
moon_json["roots"][0]["stableEndpoints"] = json.loads( moon_json["roots"][0]["stableEndpoints"] = json.loads(
Path(endpoint_config).read_text() Path(endpoint_config).read_text(),
) )
with NamedTemporaryFile("w") as f: with NamedTemporaryFile("w") as f:

View File

@@ -38,8 +38,7 @@ def get_gitea_api_url(remote: str = "origin") -> str:
host_and_path = remote_url.split("@")[1] # git.clan.lol:clan/clan-core.git host_and_path = remote_url.split("@")[1] # git.clan.lol:clan/clan-core.git
host = host_and_path.split(":")[0] # git.clan.lol host = host_and_path.split(":")[0] # git.clan.lol
repo_path = host_and_path.split(":")[1] # clan/clan-core.git repo_path = host_and_path.split(":")[1] # clan/clan-core.git
if repo_path.endswith(".git"): repo_path = repo_path.removesuffix(".git") # clan/clan-core
repo_path = repo_path[:-4] # clan/clan-core
elif remote_url.startswith("https://"): elif remote_url.startswith("https://"):
# HTTPS format: https://git.clan.lol/clan/clan-core.git # HTTPS format: https://git.clan.lol/clan/clan-core.git
url_parts = remote_url.replace("https://", "").split("/") url_parts = remote_url.replace("https://", "").split("/")
@@ -86,7 +85,10 @@ def get_repo_info_from_api_url(api_url: str) -> tuple[str, str]:
def fetch_pr_statuses( def fetch_pr_statuses(
repo_owner: str, repo_name: str, commit_sha: str, host: str repo_owner: str,
repo_name: str,
commit_sha: str,
host: str,
) -> list[dict]: ) -> list[dict]:
"""Fetch CI statuses for a specific commit SHA.""" """Fetch CI statuses for a specific commit SHA."""
status_url = ( status_url = (
@@ -183,7 +185,7 @@ def run_git_command(command: list) -> tuple[int, str, str]:
def get_current_branch_name() -> str: def get_current_branch_name() -> str:
exit_code, branch_name, error = run_git_command( exit_code, branch_name, error = run_git_command(
["git", "rev-parse", "--abbrev-ref", "HEAD"] ["git", "rev-parse", "--abbrev-ref", "HEAD"],
) )
if exit_code != 0: if exit_code != 0:
@@ -196,7 +198,7 @@ def get_current_branch_name() -> str:
def get_latest_commit_info() -> tuple[str, str]: def get_latest_commit_info() -> tuple[str, str]:
"""Get the title and body of the latest commit.""" """Get the title and body of the latest commit."""
exit_code, commit_msg, error = run_git_command( exit_code, commit_msg, error = run_git_command(
["git", "log", "-1", "--pretty=format:%B"] ["git", "log", "-1", "--pretty=format:%B"],
) )
if exit_code != 0: if exit_code != 0:
@@ -225,7 +227,7 @@ def get_commits_since_main() -> list[tuple[str, str]]:
"main..HEAD", "main..HEAD",
"--no-merges", "--no-merges",
"--pretty=format:%s|%b|---END---", "--pretty=format:%s|%b|---END---",
] ],
) )
if exit_code != 0: if exit_code != 0:
@@ -263,7 +265,9 @@ def open_editor_for_pr() -> tuple[str, str]:
commits_since_main = get_commits_since_main() commits_since_main = get_commits_since_main()
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
mode="w+", suffix="COMMIT_EDITMSG", delete=False mode="w+",
suffix="COMMIT_EDITMSG",
delete=False,
) as temp_file: ) as temp_file:
temp_file.flush() temp_file.flush()
temp_file_path = temp_file.name temp_file_path = temp_file.name
@@ -280,7 +284,7 @@ def open_editor_for_pr() -> tuple[str, str]:
temp_file.write("# The first line will be used as the PR title.\n") temp_file.write("# The first line will be used as the PR title.\n")
temp_file.write("# Everything else will be used as the PR description.\n") temp_file.write("# Everything else will be used as the PR description.\n")
temp_file.write( temp_file.write(
"# To abort creation of the PR, close editor with an error code.\n" "# To abort creation of the PR, close editor with an error code.\n",
) )
temp_file.write("# In vim for example you can use :cq!\n") temp_file.write("# In vim for example you can use :cq!\n")
temp_file.write("#\n") temp_file.write("#\n")
@@ -373,7 +377,7 @@ def create_agit_push(
print( print(
f" Description: {description[:50]}..." f" Description: {description[:50]}..."
if len(description) > 50 if len(description) > 50
else f" Description: {description}" else f" Description: {description}",
) )
print() print()
@@ -530,19 +534,26 @@ Examples:
) )
create_parser.add_argument( create_parser.add_argument(
"-t", "--topic", help="Set PR topic (default: current branch name)" "-t",
"--topic",
help="Set PR topic (default: current branch name)",
) )
create_parser.add_argument( create_parser.add_argument(
"--title", help="Set the PR title (default: last commit title)" "--title",
help="Set the PR title (default: last commit title)",
) )
create_parser.add_argument( create_parser.add_argument(
"--description", help="Override the PR description (default: commit body)" "--description",
help="Override the PR description (default: commit body)",
) )
create_parser.add_argument( create_parser.add_argument(
"-f", "--force", action="store_true", help="Force push the changes" "-f",
"--force",
action="store_true",
help="Force push the changes",
) )
create_parser.add_argument( create_parser.add_argument(

View File

@@ -13,7 +13,9 @@ log = logging.getLogger(__name__)
def main(argv: list[str] = sys.argv) -> int: def main(argv: list[str] = sys.argv) -> int:
parser = argparse.ArgumentParser(description="Clan App") parser = argparse.ArgumentParser(description="Clan App")
parser.add_argument( parser.add_argument(
"--content-uri", type=str, help="The URI of the content to display" "--content-uri",
type=str,
help="The URI of the content to display",
) )
parser.add_argument("--debug", action="store_true", help="Enable debug mode") parser.add_argument("--debug", action="store_true", help="Enable debug mode")
parser.add_argument( parser.add_argument(

View File

@@ -56,18 +56,23 @@ class ApiBridge(ABC):
for middleware in self.middleware_chain: for middleware in self.middleware_chain:
try: try:
log.debug( log.debug(
f"{middleware.__class__.__name__} => {request.method_name}" f"{middleware.__class__.__name__} => {request.method_name}",
) )
middleware.process(context) middleware.process(context)
except Exception as e: except Exception as e:
# If middleware fails, handle error # If middleware fails, handle error
self.send_api_error_response( self.send_api_error_response(
request.op_key or "unknown", str(e), ["middleware_error"] request.op_key or "unknown",
str(e),
["middleware_error"],
) )
return return
def send_api_error_response( def send_api_error_response(
self, op_key: str, error_message: str, location: list[str] self,
op_key: str,
error_message: str,
location: list[str],
) -> None: ) -> None:
"""Send an error response.""" """Send an error response."""
from clan_lib.api import ApiError, ErrorDataClass from clan_lib.api import ApiError, ErrorDataClass
@@ -80,7 +85,7 @@ class ApiBridge(ABC):
message="An internal error occured", message="An internal error occured",
description=error_message, description=error_message,
location=location, location=location,
) ),
], ],
) )
@@ -107,6 +112,7 @@ class ApiBridge(ABC):
thread_name: Name for the thread (for debugging) thread_name: Name for the thread (for debugging)
wait_for_completion: Whether to wait for the thread to complete wait_for_completion: Whether to wait for the thread to complete
timeout: Timeout in seconds when waiting for completion timeout: Timeout in seconds when waiting for completion
""" """
op_key = request.op_key or "unknown" op_key = request.op_key or "unknown"
@@ -116,7 +122,7 @@ class ApiBridge(ABC):
try: try:
log.debug( log.debug(
f"Processing {request.method_name} with args {request.args} " f"Processing {request.method_name} with args {request.args} "
f"and header {request.header} in thread {thread_name}" f"and header {request.header} in thread {thread_name}",
) )
self.process_request(request) self.process_request(request)
finally: finally:
@@ -124,7 +130,9 @@ class ApiBridge(ABC):
stop_event = threading.Event() stop_event = threading.Event()
thread = threading.Thread( thread = threading.Thread(
target=thread_task, args=(stop_event,), name=thread_name target=thread_task,
args=(stop_event,),
name=thread_name,
) )
thread.start() thread.start()
@@ -138,5 +146,7 @@ class ApiBridge(ABC):
if thread.is_alive(): if thread.is_alive():
stop_event.set() # Cancel the thread stop_event.set() # Cancel the thread
self.send_api_error_response( self.send_api_error_response(
op_key, "Request timeout", ["api_bridge", request.method_name] op_key,
"Request timeout",
["api_bridge", request.method_name],
) )

View File

@@ -26,8 +26,7 @@ RESULT: dict[str, SuccessDataClass[list[str] | None] | ErrorDataClass] = {}
def get_clan_folder() -> SuccessDataClass[Flake] | ErrorDataClass: def get_clan_folder() -> SuccessDataClass[Flake] | ErrorDataClass:
""" """Opens the clan folder using the GTK file dialog.
Opens the clan folder using the GTK file dialog.
Returns the path to the clan folder or an error if it fails. Returns the path to the clan folder or an error if it fails.
""" """
file_request = FileRequest( file_request = FileRequest(
@@ -52,7 +51,7 @@ def get_clan_folder() -> SuccessDataClass[Flake] | ErrorDataClass:
message="No folder selected", message="No folder selected",
description="You must select a folder to open.", description="You must select a folder to open.",
location=["get_clan_folder"], location=["get_clan_folder"],
) ),
], ],
) )
@@ -66,7 +65,7 @@ def get_clan_folder() -> SuccessDataClass[Flake] | ErrorDataClass:
message="Invalid clan folder", message="Invalid clan folder",
description=f"The selected folder '{clan_folder}' is not a valid clan folder.", description=f"The selected folder '{clan_folder}' is not a valid clan folder.",
location=["get_clan_folder"], location=["get_clan_folder"],
) ),
], ],
) )
@@ -102,8 +101,10 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
selected_path = remove_none([gfile.get_path()]) selected_path = remove_none([gfile.get_path()])
returns( returns(
SuccessDataClass( SuccessDataClass(
op_key=op_key, data=selected_path, status="success" op_key=op_key,
) data=selected_path,
status="success",
),
) )
except Exception as e: except Exception as e:
log.exception("Error opening file") log.exception("Error opening file")
@@ -116,9 +117,9 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
message=e.__class__.__name__, message=e.__class__.__name__,
description=str(e), description=str(e),
location=["get_system_file"], location=["get_system_file"],
) ),
], ],
) ),
) )
def on_file_select_multiple(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None: def on_file_select_multiple(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None:
@@ -128,8 +129,10 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
selected_paths = remove_none([gfile.get_path() for gfile in gfiles]) selected_paths = remove_none([gfile.get_path() for gfile in gfiles])
returns( returns(
SuccessDataClass( SuccessDataClass(
op_key=op_key, data=selected_paths, status="success" op_key=op_key,
) data=selected_paths,
status="success",
),
) )
else: else:
returns(SuccessDataClass(op_key=op_key, data=None, status="success")) returns(SuccessDataClass(op_key=op_key, data=None, status="success"))
@@ -144,9 +147,9 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
message=e.__class__.__name__, message=e.__class__.__name__,
description=str(e), description=str(e),
location=["get_system_file"], location=["get_system_file"],
) ),
], ],
) ),
) )
def on_folder_select(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None: def on_folder_select(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None:
@@ -156,8 +159,10 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
selected_path = remove_none([gfile.get_path()]) selected_path = remove_none([gfile.get_path()])
returns( returns(
SuccessDataClass( SuccessDataClass(
op_key=op_key, data=selected_path, status="success" op_key=op_key,
) data=selected_path,
status="success",
),
) )
else: else:
returns(SuccessDataClass(op_key=op_key, data=None, status="success")) returns(SuccessDataClass(op_key=op_key, data=None, status="success"))
@@ -172,9 +177,9 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
message=e.__class__.__name__, message=e.__class__.__name__,
description=str(e), description=str(e),
location=["get_system_file"], location=["get_system_file"],
) ),
], ],
) ),
) )
def on_save_finish(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None: def on_save_finish(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None:
@@ -184,8 +189,10 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
selected_path = remove_none([gfile.get_path()]) selected_path = remove_none([gfile.get_path()])
returns( returns(
SuccessDataClass( SuccessDataClass(
op_key=op_key, data=selected_path, status="success" op_key=op_key,
) data=selected_path,
status="success",
),
) )
else: else:
returns(SuccessDataClass(op_key=op_key, data=None, status="success")) returns(SuccessDataClass(op_key=op_key, data=None, status="success"))
@@ -200,9 +207,9 @@ def gtk_open_file(file_request: FileRequest, op_key: str) -> bool:
message=e.__class__.__name__, message=e.__class__.__name__,
description=str(e), description=str(e),
location=["get_system_file"], location=["get_system_file"],
) ),
], ],
) ),
) )
dialog = Gtk.FileDialog() dialog = Gtk.FileDialog()

View File

@@ -39,7 +39,7 @@ class ArgumentParsingMiddleware(Middleware):
except Exception as e: except Exception as e:
log.exception( log.exception(
f"Error while parsing arguments for {context.request.method_name}" f"Error while parsing arguments for {context.request.method_name}",
) )
context.bridge.send_api_error_response( context.bridge.send_api_error_response(
context.request.op_key or "unknown", context.request.op_key or "unknown",

View File

@@ -23,7 +23,9 @@ class Middleware(ABC):
"""Process the request through this middleware.""" """Process the request through this middleware."""
def register_context_manager( def register_context_manager(
self, context: MiddlewareContext, cm: AbstractContextManager[Any] self,
context: MiddlewareContext,
cm: AbstractContextManager[Any],
) -> Any: ) -> Any:
"""Register a context manager with the exit stack.""" """Register a context manager with the exit stack."""
return context.exit_stack.enter_context(cm) return context.exit_stack.enter_context(cm)

View File

@@ -25,23 +25,26 @@ class LoggingMiddleware(Middleware):
try: try:
# Handle log group configuration # Handle log group configuration
log_group: list[str] | None = context.request.header.get("logging", {}).get( log_group: list[str] | None = context.request.header.get("logging", {}).get(
"group_path", None "group_path",
None,
) )
if log_group is not None: if log_group is not None:
if not isinstance(log_group, list): if not isinstance(log_group, list):
msg = f"Expected log_group to be a list, got {type(log_group)}" msg = f"Expected log_group to be a list, got {type(log_group)}"
raise TypeError(msg) # noqa: TRY301 raise TypeError(msg) # noqa: TRY301
log.warning( log.warning(
f"Using log group {log_group} for {context.request.method_name} with op_key {context.request.op_key}" f"Using log group {log_group} for {context.request.method_name} with op_key {context.request.op_key}",
) )
# Create log file # Create log file
log_file = self.log_manager.create_log_file( log_file = self.log_manager.create_log_file(
method, op_key=context.request.op_key or "unknown", group_path=log_group method,
op_key=context.request.op_key or "unknown",
group_path=log_group,
).get_file_path() ).get_file_path()
except Exception as e: except Exception as e:
log.exception( log.exception(
f"Error while handling request header of {context.request.method_name}" f"Error while handling request header of {context.request.method_name}",
) )
context.bridge.send_api_error_response( context.bridge.send_api_error_response(
context.request.op_key or "unknown", context.request.op_key or "unknown",
@@ -76,7 +79,8 @@ class LoggingMiddleware(Middleware):
line_buffering=True, line_buffering=True,
) )
self.handler = setup_logging( self.handler = setup_logging(
log.getEffectiveLevel(), log_file=handler_stream log.getEffectiveLevel(),
log_file=handler_stream,
) )
return self return self

View File

@@ -32,7 +32,7 @@ class MethodExecutionMiddleware(Middleware):
except Exception as e: except Exception as e:
log.exception( log.exception(
f"Error while handling result of {context.request.method_name}" f"Error while handling result of {context.request.method_name}",
) )
context.bridge.send_api_error_response( context.bridge.send_api_error_response(
context.request.op_key or "unknown", context.request.op_key or "unknown",

View File

@@ -48,7 +48,7 @@ def app_run(app_opts: ClanAppOptions) -> int:
# Add a log group ["clans", <dynamic_name>, "machines", <dynamic_name>] # Add a log group ["clans", <dynamic_name>, "machines", <dynamic_name>]
log_manager = LogManager(base_dir=user_data_dir() / "clan-app" / "logs") log_manager = LogManager(base_dir=user_data_dir() / "clan-app" / "logs")
clan_log_group = LogGroupConfig("clans", "Clans").add_child( clan_log_group = LogGroupConfig("clans", "Clans").add_child(
LogGroupConfig("machines", "Machines") LogGroupConfig("machines", "Machines"),
) )
log_manager = log_manager.add_root_group_config(clan_log_group) log_manager = log_manager.add_root_group_config(clan_log_group)
# Init LogManager global in log_manager_api module # Init LogManager global in log_manager_api module
@@ -89,7 +89,7 @@ def app_run(app_opts: ClanAppOptions) -> int:
# HTTP-only mode - keep the server running # HTTP-only mode - keep the server running
log.info("HTTP API server running...") log.info("HTTP API server running...")
log.info( log.info(
f"Swagger: http://{app_opts.http_host}:{app_opts.http_port}/api/swagger" f"Swagger: http://{app_opts.http_host}:{app_opts.http_port}/api/swagger",
) )
log.info("Press Ctrl+C to stop the server") log.info("Press Ctrl+C to stop the server")

View File

@@ -63,7 +63,9 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
self.send_header("Access-Control-Allow-Headers", "Content-Type") self.send_header("Access-Control-Allow-Headers", "Content-Type")
def _send_json_response_with_status( def _send_json_response_with_status(
self, data: dict[str, Any], status_code: int = 200 self,
data: dict[str, Any],
status_code: int = 200,
) -> None: ) -> None:
"""Send a JSON response with the given status code.""" """Send a JSON response with the given status code."""
try: try:
@@ -82,11 +84,13 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
response_dict = dataclass_to_dict(response) response_dict = dataclass_to_dict(response)
self._send_json_response_with_status(response_dict, 200) self._send_json_response_with_status(response_dict, 200)
log.debug( log.debug(
f"HTTP response for {response._op_key}: {json.dumps(response_dict, indent=2)}" # noqa: SLF001 f"HTTP response for {response._op_key}: {json.dumps(response_dict, indent=2)}", # noqa: SLF001
) )
def _create_success_response( def _create_success_response(
self, op_key: str, data: dict[str, Any] self,
op_key: str,
data: dict[str, Any],
) -> BackendResponse: ) -> BackendResponse:
"""Create a successful API response.""" """Create a successful API response."""
return BackendResponse( return BackendResponse(
@@ -98,14 +102,16 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
def _send_info_response(self) -> None: def _send_info_response(self) -> None:
"""Send server information response.""" """Send server information response."""
response = self._create_success_response( response = self._create_success_response(
"info", {"message": "Clan API Server", "version": "1.0.0"} "info",
{"message": "Clan API Server", "version": "1.0.0"},
) )
self.send_api_response(response) self.send_api_response(response)
def _send_methods_response(self) -> None: def _send_methods_response(self) -> None:
"""Send available API methods response.""" """Send available API methods response."""
response = self._create_success_response( response = self._create_success_response(
"methods", {"methods": list(self.api.functions.keys())} "methods",
{"methods": list(self.api.functions.keys())},
) )
self.send_api_response(response) self.send_api_response(response)
@@ -179,7 +185,7 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
json_data = json.loads(file_data.decode("utf-8")) json_data = json.loads(file_data.decode("utf-8"))
server_address = getattr(self.server, "server_address", ("localhost", 80)) server_address = getattr(self.server, "server_address", ("localhost", 80))
json_data["servers"] = [ json_data["servers"] = [
{"url": f"http://{server_address[0]}:{server_address[1]}/api/v1/"} {"url": f"http://{server_address[0]}:{server_address[1]}/api/v1/"},
] ]
file_data = json.dumps(json_data, indent=2).encode("utf-8") file_data = json.dumps(json_data, indent=2).encode("utf-8")
@@ -213,7 +219,9 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
# Validate API path # Validate API path
if not path.startswith("/api/v1/"): if not path.startswith("/api/v1/"):
self.send_api_error_response( self.send_api_error_response(
"post", f"Path not found: {path}", ["http_bridge", "POST"] "post",
f"Path not found: {path}",
["http_bridge", "POST"],
) )
return return
@@ -221,7 +229,9 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
method_name = path[len("/api/v1/") :] method_name = path[len("/api/v1/") :]
if not method_name: if not method_name:
self.send_api_error_response( self.send_api_error_response(
"post", "Method name required", ["http_bridge", "POST"] "post",
"Method name required",
["http_bridge", "POST"],
) )
return return
@@ -289,19 +299,26 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
# Create API request # Create API request
api_request = BackendRequest( api_request = BackendRequest(
method_name=method_name, args=body, header=header, op_key=op_key method_name=method_name,
args=body,
header=header,
op_key=op_key,
) )
except Exception as e: except Exception as e:
self.send_api_error_response( self.send_api_error_response(
gen_op_key, str(e), ["http_bridge", method_name] gen_op_key,
str(e),
["http_bridge", method_name],
) )
return return
self._process_api_request_in_thread(api_request, method_name) self._process_api_request_in_thread(api_request, method_name)
def _parse_request_data( def _parse_request_data(
self, request_data: dict[str, Any], gen_op_key: str self,
request_data: dict[str, Any],
gen_op_key: str,
) -> tuple[dict[str, Any], dict[str, Any], str]: ) -> tuple[dict[str, Any], dict[str, Any], str]:
"""Parse and validate request data components.""" """Parse and validate request data components."""
header = request_data.get("header", {}) header = request_data.get("header", {})
@@ -344,7 +361,9 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
pass pass
def _process_api_request_in_thread( def _process_api_request_in_thread(
self, api_request: BackendRequest, method_name: str self,
api_request: BackendRequest,
method_name: str,
) -> None: ) -> None:
"""Process the API request in a separate thread.""" """Process the API request in a separate thread."""
stop_event = threading.Event() stop_event = threading.Event()
@@ -358,7 +377,7 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
log.debug( log.debug(
f"Processing {request.method_name} with args {request.args} " f"Processing {request.method_name} with args {request.args} "
f"and header {request.header}" f"and header {request.header}",
) )
self.process_request(request) self.process_request(request)

View File

@@ -64,7 +64,8 @@ def mock_log_manager() -> Mock:
@pytest.fixture @pytest.fixture
def http_bridge( def http_bridge(
mock_api: MethodRegistry, mock_log_manager: Mock mock_api: MethodRegistry,
mock_log_manager: Mock,
) -> tuple[MethodRegistry, tuple]: ) -> tuple[MethodRegistry, tuple]:
"""Create HTTP bridge dependencies for testing.""" """Create HTTP bridge dependencies for testing."""
middleware_chain = ( middleware_chain = (
@@ -256,7 +257,9 @@ class TestIntegration:
"""Integration tests for HTTP API components.""" """Integration tests for HTTP API components."""
def test_full_request_flow( def test_full_request_flow(
self, mock_api: MethodRegistry, mock_log_manager: Mock self,
mock_api: MethodRegistry,
mock_log_manager: Mock,
) -> None: ) -> None:
"""Test complete request flow from server to bridge to middleware.""" """Test complete request flow from server to bridge to middleware."""
server: HttpApiServer = HttpApiServer( server: HttpApiServer = HttpApiServer(
@@ -301,7 +304,9 @@ class TestIntegration:
server.stop() server.stop()
def test_blocking_task( def test_blocking_task(
self, mock_api: MethodRegistry, mock_log_manager: Mock self,
mock_api: MethodRegistry,
mock_log_manager: Mock,
) -> None: ) -> None:
shared_threads: dict[str, tasks.WebThread] = {} shared_threads: dict[str, tasks.WebThread] = {}
tasks.BAKEND_THREADS = shared_threads tasks.BAKEND_THREADS = shared_threads

View File

@@ -36,7 +36,6 @@ def _get_lib_names() -> list[str]:
def _be_sure_libraries() -> list[Path] | None: def _be_sure_libraries() -> list[Path] | None:
"""Ensure libraries exist and return paths.""" """Ensure libraries exist and return paths."""
lib_dir = os.environ.get("WEBVIEW_LIB_DIR") lib_dir = os.environ.get("WEBVIEW_LIB_DIR")
if not lib_dir: if not lib_dir:
msg = "WEBVIEW_LIB_DIR environment variable is not set" msg = "WEBVIEW_LIB_DIR environment variable is not set"

View File

@@ -144,7 +144,9 @@ class Webview:
) )
else: else:
bridge = WebviewBridge( bridge = WebviewBridge(
webview=self, middleware_chain=tuple(self._middleware), threads={} webview=self,
middleware_chain=tuple(self._middleware),
threads={},
) )
self._bridge = bridge self._bridge = bridge
@@ -154,7 +156,10 @@ class Webview:
def set_size(self, value: Size) -> None: def set_size(self, value: Size) -> None:
"""Set the webview size (legacy compatibility).""" """Set the webview size (legacy compatibility)."""
_webview_lib.webview_set_size( _webview_lib.webview_set_size(
self.handle, value.width, value.height, value.hint self.handle,
value.width,
value.height,
value.hint,
) )
def set_title(self, value: str) -> None: def set_title(self, value: str) -> None:
@@ -194,7 +199,10 @@ class Webview:
self._callbacks[name] = c_callback self._callbacks[name] = c_callback
_webview_lib.webview_bind( _webview_lib.webview_bind(
self.handle, _encode_c_string(name), c_callback, None self.handle,
_encode_c_string(name),
c_callback,
None,
) )
def bind(self, name: str, callback: Callable[..., Any]) -> None: def bind(self, name: str, callback: Callable[..., Any]) -> None:
@@ -219,7 +227,10 @@ class Webview:
def return_(self, seq: str, status: int, result: str) -> None: def return_(self, seq: str, status: int, result: str) -> None:
_webview_lib.webview_return( _webview_lib.webview_return(
self.handle, _encode_c_string(seq), status, _encode_c_string(result) self.handle,
_encode_c_string(seq),
status,
_encode_c_string(result),
) )
def eval(self, source: str) -> None: def eval(self, source: str) -> None:

View File

@@ -26,7 +26,9 @@ class WebviewBridge(ApiBridge):
def send_api_response(self, response: BackendResponse) -> None: def send_api_response(self, response: BackendResponse) -> None:
"""Send response back to the webview client.""" """Send response back to the webview client."""
serialized = json.dumps( serialized = json.dumps(
dataclass_to_dict(response), indent=4, ensure_ascii=False dataclass_to_dict(response),
indent=4,
ensure_ascii=False,
) )
log.debug(f"Sending response: {serialized}") log.debug(f"Sending response: {serialized}")
@@ -40,7 +42,6 @@ class WebviewBridge(ApiBridge):
arg: int, arg: int,
) -> None: ) -> None:
"""Handle a call from webview's JavaScript bridge.""" """Handle a call from webview's JavaScript bridge."""
try: try:
op_key = op_key_bytes.decode() op_key = op_key_bytes.decode()
raw_args = json.loads(request_data.decode()) raw_args = json.loads(request_data.decode())
@@ -68,7 +69,10 @@ class WebviewBridge(ApiBridge):
# Create API request # Create API request
api_request = BackendRequest( api_request = BackendRequest(
method_name=method_name, args=args, header=header, op_key=op_key method_name=method_name,
args=args,
header=header,
op_key=op_key,
) )
except Exception as e: except Exception as e:
@@ -77,7 +81,9 @@ class WebviewBridge(ApiBridge):
) )
log.exception(msg) log.exception(msg)
self.send_api_error_response( self.send_api_error_response(
op_key, str(e), ["webview_bridge", method_name] op_key,
str(e),
["webview_bridge", method_name],
) )
return return

View File

@@ -54,8 +54,7 @@ class Command:
@pytest.fixture @pytest.fixture
def command() -> Iterator[Command]: def command() -> Iterator[Command]:
""" """Starts a background command. The process is automatically terminated in the end.
Starts a background command. The process is automatically terminated in the end.
>>> p = command.run(["some", "daemon"]) >>> p = command.run(["some", "daemon"])
>>> print(p.pid) >>> print(p.pid)
""" """

View File

@@ -13,23 +13,17 @@ else:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def project_root() -> Path: def project_root() -> Path:
""" """Root directory the clan-cli"""
Root directory the clan-cli
"""
return PROJECT_ROOT return PROJECT_ROOT
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def test_root() -> Path: def test_root() -> Path:
""" """Root directory of the tests"""
Root directory of the tests
"""
return TEST_ROOT return TEST_ROOT
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def clan_core() -> Path: def clan_core() -> Path:
""" """Directory of the clan-core flake"""
Directory of the clan-core flake
"""
return CLAN_CORE return CLAN_CORE

View File

@@ -24,7 +24,11 @@ def app() -> Generator[GtkProc]:
cmd = [sys.executable, "-m", "clan_app"] cmd = [sys.executable, "-m", "clan_app"]
print(f"Running: {cmd}") print(f"Running: {cmd}")
rapp = Popen( rapp = Popen(
cmd, text=True, stdout=sys.stdout, stderr=sys.stderr, start_new_session=True cmd,
text=True,
stdout=sys.stdout,
stderr=sys.stderr,
start_new_session=True,
) )
yield GtkProc(rapp) yield GtkProc(rapp)
# Cleanup: Terminate your application # Cleanup: Terminate your application

View File

@@ -22,12 +22,16 @@ def create_command(args: argparse.Namespace) -> None:
def register_create_parser(parser: argparse.ArgumentParser) -> None: def register_create_parser(parser: argparse.ArgumentParser) -> None:
machines_parser = parser.add_argument( machines_parser = parser.add_argument(
"machine", type=str, help="machine in the flake to create backups of" "machine",
type=str,
help="machine in the flake to create backups of",
) )
add_dynamic_completer(machines_parser, complete_machines) add_dynamic_completer(machines_parser, complete_machines)
provider_action = parser.add_argument( provider_action = parser.add_argument(
"--provider", type=str, help="backup provider to use" "--provider",
type=str,
help="backup provider to use",
) )
add_dynamic_completer(provider_action, complete_backup_providers_for_machine) add_dynamic_completer(provider_action, complete_backup_providers_for_machine)
parser.set_defaults(func=create_command) parser.set_defaults(func=create_command)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_create_command_no_flake( def test_create_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -21,11 +21,15 @@ def list_command(args: argparse.Namespace) -> None:
def register_list_parser(parser: argparse.ArgumentParser) -> None: def register_list_parser(parser: argparse.ArgumentParser) -> None:
machines_parser = parser.add_argument( machines_parser = parser.add_argument(
"machine", type=str, help="machine in the flake to show backups of" "machine",
type=str,
help="machine in the flake to show backups of",
) )
add_dynamic_completer(machines_parser, complete_machines) add_dynamic_completer(machines_parser, complete_machines)
provider_action = parser.add_argument( provider_action = parser.add_argument(
"--provider", type=str, help="backup provider to filter by" "--provider",
type=str,
help="backup provider to filter by",
) )
add_dynamic_completer(provider_action, complete_backup_providers_for_machine) add_dynamic_completer(provider_action, complete_backup_providers_for_machine)
parser.set_defaults(func=list_command) parser.set_defaults(func=list_command)

View File

@@ -24,11 +24,15 @@ def restore_command(args: argparse.Namespace) -> None:
def register_restore_parser(parser: argparse.ArgumentParser) -> None: def register_restore_parser(parser: argparse.ArgumentParser) -> None:
machine_action = parser.add_argument( machine_action = parser.add_argument(
"machine", type=str, help="machine in the flake to create backups of" "machine",
type=str,
help="machine in the flake to create backups of",
) )
add_dynamic_completer(machine_action, complete_machines) add_dynamic_completer(machine_action, complete_machines)
provider_action = parser.add_argument( provider_action = parser.add_argument(
"provider", type=str, help="backup provider to use" "provider",
type=str,
help="backup provider to use",
) )
add_dynamic_completer(provider_action, complete_backup_providers_for_machine) add_dynamic_completer(provider_action, complete_backup_providers_for_machine)
parser.add_argument("name", type=str, help="Name of the backup to restore") parser.add_argument("name", type=str, help="Name of the backup to restore")

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_restore_command_no_flake( def test_restore_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -67,7 +67,7 @@ def register_create_parser(parser: argparse.ArgumentParser) -> None:
setup_git=not args.no_git, setup_git=not args.no_git,
src_flake=args.flake, src_flake=args.flake,
update_clan=not args.no_update, update_clan=not args.no_update,
) ),
) )
create_secrets_user_auto( create_secrets_user_auto(
flake_dir=Path(args.name).resolve(), flake_dir=Path(args.name).resolve(),

View File

@@ -74,8 +74,8 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
# Get the Clan name # Get the Clan name
cmd = nix_eval( cmd = nix_eval(
[ [
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.name' f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.name',
] ],
) )
res = run_cmd(cmd) res = run_cmd(cmd)
clan_name = res.strip('"') clan_name = res.strip('"')
@@ -83,8 +83,8 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
# Get the clan icon path # Get the clan icon path
cmd = nix_eval( cmd = nix_eval(
[ [
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon' f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon',
] ],
) )
res = run_cmd(cmd) res = run_cmd(cmd)
@@ -96,7 +96,7 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
cmd = nix_build( cmd = nix_build(
[ [
f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon' f'{flake_url}#clanInternals.machines."{system}"."{machine_name}".config.clan.core.icon',
], ],
machine_gcroot(flake_url=str(flake_url)) / "icon", machine_gcroot(flake_url=str(flake_url)) / "icon",
) )
@@ -129,7 +129,8 @@ def inspect_command(args: argparse.Namespace) -> None:
flake=args.flake or Flake(str(Path.cwd())), flake=args.flake or Flake(str(Path.cwd())),
) )
res = inspect_flake( res = inspect_flake(
flake_url=str(inspect_options.flake), machine_name=inspect_options.machine flake_url=str(inspect_options.flake),
machine_name=inspect_options.machine,
) )
print("Clan name:", res.clan_name) print("Clan name:", res.clan_name)
print("Icon:", res.icon) print("Icon:", res.icon)

View File

@@ -10,7 +10,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core @pytest.mark.with_core
def test_clan_show( def test_clan_show(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run(["show", "--flake", str(test_flake_with_core.path)]) cli.run(["show", "--flake", str(test_flake_with_core.path)])
@@ -20,7 +21,9 @@ def test_clan_show(
def test_clan_show_no_flake( def test_clan_show_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capture_output: CaptureOutput tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
capture_output: CaptureOutput,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)
@@ -28,8 +31,8 @@ def test_clan_show_no_flake(
cli.run(["show"]) cli.run(["show"])
assert "No clan flake found in the current directory or its parents" in str( assert "No clan flake found in the current directory or its parents" in str(
exc_info.value exc_info.value,
) )
assert "Use the --flake flag to specify a clan flake path or URL" in str( assert "Use the --flake flag to specify a clan flake path or URL" in str(
exc_info.value exc_info.value,
) )

View File

@@ -52,8 +52,7 @@ def create_flake_from_args(args: argparse.Namespace) -> Flake:
def add_common_flags(parser: argparse.ArgumentParser) -> None: def add_common_flags(parser: argparse.ArgumentParser) -> None:
def argument_exists(parser: argparse.ArgumentParser, arg: str) -> bool: def argument_exists(parser: argparse.ArgumentParser, arg: str) -> bool:
""" """Check if an argparse argument already exists.
Check if an argparse argument already exists.
This is needed because the aliases subcommand doesn't *really* This is needed because the aliases subcommand doesn't *really*
create an alias - it duplicates the actual parser in the tree create an alias - it duplicates the actual parser in the tree
making duplication inevitable while naively traversing. making duplication inevitable while naively traversing.
@@ -410,7 +409,9 @@ For more detailed information, visit: {help_hyperlink("deploy", "https://docs.cl
machines.register_parser(parser_machine) machines.register_parser(parser_machine)
parser_vms = subparsers.add_parser( parser_vms = subparsers.add_parser(
"vms", help="Manage virtual machines", description="Manage virtual machines" "vms",
help="Manage virtual machines",
description="Manage virtual machines",
) )
vms.register_parser(parser_vms) vms.register_parser(parser_vms)

View File

@@ -38,11 +38,11 @@ def clan_dir(flake: str | None) -> str | None:
def complete_machines( def complete_machines(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for machine names configured in the clan."""
Provides completion functionality for machine names configured in the clan.
"""
machines: list[str] = [] machines: list[str] = []
def run_cmd() -> None: def run_cmd() -> None:
@@ -72,11 +72,11 @@ def complete_machines(
def complete_services_for_machine( def complete_services_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for machine facts generation services."""
Provides completion functionality for machine facts generation services.
"""
services: list[str] = [] services: list[str] = []
# TODO: consolidate, if multiple machines are used # TODO: consolidate, if multiple machines are used
machines: list[str] = parsed_args.machines machines: list[str] = parsed_args.machines
@@ -98,7 +98,7 @@ def complete_services_for_machine(
"builtins.attrNames", "builtins.attrNames",
], ],
), ),
).stdout.strip() ).stdout.strip(),
) )
services.extend(services_result) services.extend(services_result)
@@ -117,11 +117,11 @@ def complete_services_for_machine(
def complete_backup_providers_for_machine( def complete_backup_providers_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for machine backup providers."""
Provides completion functionality for machine backup providers.
"""
providers: list[str] = [] providers: list[str] = []
machine: str = parsed_args.machine machine: str = parsed_args.machine
@@ -142,7 +142,7 @@ def complete_backup_providers_for_machine(
"builtins.attrNames", "builtins.attrNames",
], ],
), ),
).stdout.strip() ).stdout.strip(),
) )
providers.extend(providers_result) providers.extend(providers_result)
@@ -161,11 +161,11 @@ def complete_backup_providers_for_machine(
def complete_state_services_for_machine( def complete_state_services_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for machine state providers."""
Provides completion functionality for machine state providers.
"""
providers: list[str] = [] providers: list[str] = []
machine: str = parsed_args.machine machine: str = parsed_args.machine
@@ -186,7 +186,7 @@ def complete_state_services_for_machine(
"builtins.attrNames", "builtins.attrNames",
], ],
), ),
).stdout.strip() ).stdout.strip(),
) )
providers.extend(providers_result) providers.extend(providers_result)
@@ -205,11 +205,11 @@ def complete_state_services_for_machine(
def complete_secrets( def complete_secrets(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for clan secrets"""
Provides completion functionality for clan secrets
"""
from clan_lib.flake.flake import Flake from clan_lib.flake.flake import Flake
from .secrets.secrets import list_secrets from .secrets.secrets import list_secrets
@@ -228,11 +228,11 @@ def complete_secrets(
def complete_users( def complete_users(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for clan users"""
Provides completion functionality for clan users
"""
from pathlib import Path from pathlib import Path
from .secrets.users import list_users from .secrets.users import list_users
@@ -251,11 +251,11 @@ def complete_users(
def complete_groups( def complete_groups(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for clan groups"""
Provides completion functionality for clan groups
"""
from pathlib import Path from pathlib import Path
from .secrets.groups import list_groups from .secrets.groups import list_groups
@@ -275,12 +275,11 @@ def complete_groups(
def complete_templates_disko( def complete_templates_disko(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for disko templates"""
Provides completion functionality for disko templates
"""
from clan_lib.templates import list_templates from clan_lib.templates import list_templates
flake = ( flake = (
@@ -300,12 +299,11 @@ def complete_templates_disko(
def complete_templates_clan( def complete_templates_clan(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for clan templates"""
Provides completion functionality for clan templates
"""
from clan_lib.templates import list_templates from clan_lib.templates import list_templates
flake = ( flake = (
@@ -325,10 +323,11 @@ def complete_templates_clan(
def complete_vars_for_machine( def complete_vars_for_machine(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for variable names for a specific machine.
Provides completion functionality for variable names for a specific machine.
Only completes vars that already exist in the vars directory on disk. Only completes vars that already exist in the vars directory on disk.
This is fast as it only scans the filesystem without any evaluation. This is fast as it only scans the filesystem without any evaluation.
""" """
@@ -368,11 +367,11 @@ def complete_vars_for_machine(
def complete_target_host( def complete_target_host(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for target_host for a specific machine"""
Provides completion functionality for target_host for a specific machine
"""
target_hosts: list[str] = [] target_hosts: list[str] = []
machine: str = parsed_args.machine machine: str = parsed_args.machine
@@ -391,7 +390,7 @@ def complete_target_host(
f"{flake}#nixosConfigurations.{machine}.config.clan.core.networking.targetHost", f"{flake}#nixosConfigurations.{machine}.config.clan.core.networking.targetHost",
], ],
), ),
).stdout.strip() ).stdout.strip(),
) )
target_hosts.append(target_host_result) target_hosts.append(target_host_result)
@@ -410,11 +409,11 @@ def complete_target_host(
def complete_tags( def complete_tags(
prefix: str, parsed_args: argparse.Namespace, **kwargs: Any prefix: str,
parsed_args: argparse.Namespace,
**kwargs: Any,
) -> Iterable[str]: ) -> Iterable[str]:
""" """Provides completion functionality for tags inside the inventory"""
Provides completion functionality for tags inside the inventory
"""
tags: list[str] = [] tags: list[str] = []
threads = [] threads = []
@@ -483,8 +482,7 @@ def add_dynamic_completer(
action: argparse.Action, action: argparse.Action,
completer: Callable[..., Iterable[str]], completer: Callable[..., Iterable[str]],
) -> None: ) -> None:
""" """Add a completion function to an argparse action, this will only be added,
Add a completion function to an argparse action, this will only be added,
if the argcomplete module is loaded. if the argcomplete module is loaded.
""" """
if argcomplete: if argcomplete:

View File

@@ -21,14 +21,14 @@ def check_secrets(machine: Machine, service: None | str = None) -> bool:
secret_name = secret_fact["name"] secret_name = secret_fact["name"]
if not machine.secret_facts_store.exists(service, secret_name): if not machine.secret_facts_store.exists(service, secret_name):
machine.info( machine.info(
f"Secret fact '{secret_fact}' for service '{service}' is missing." f"Secret fact '{secret_fact}' for service '{service}' is missing.",
) )
missing_secret_facts.append((service, secret_name)) missing_secret_facts.append((service, secret_name))
for public_fact in machine.facts_data[service]["public"]: for public_fact in machine.facts_data[service]["public"]:
if not machine.public_facts_store.exists(service, public_fact): if not machine.public_facts_store.exists(service, public_fact):
machine.info( machine.info(
f"Public fact '{public_fact}' for service '{service}' is missing." f"Public fact '{public_fact}' for service '{service}' is missing.",
) )
missing_public_facts.append((service, public_fact)) missing_public_facts.append((service, public_fact))

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_check_command_no_flake( def test_check_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -29,9 +29,7 @@ log = logging.getLogger(__name__)
def read_multiline_input(prompt: str = "Finish with Ctrl-D") -> str: def read_multiline_input(prompt: str = "Finish with Ctrl-D") -> str:
""" """Read multi-line input from stdin."""
Read multi-line input from stdin.
"""
print(prompt, flush=True) print(prompt, flush=True)
proc = run(["cat"], RunOpts(check=False)) proc = run(["cat"], RunOpts(check=False))
log.info("Input received. Processing...") log.info("Input received. Processing...")
@@ -63,7 +61,7 @@ def bubblewrap_cmd(generator: str, facts_dir: Path, secrets_dir: Path) -> list[s
"--uid", "1000", "--uid", "1000",
"--gid", "1000", "--gid", "1000",
"--", "--",
"bash", "-c", generator "bash", "-c", generator,
], ],
) )
# fmt: on # fmt: on
@@ -102,7 +100,8 @@ def generate_service_facts(
generator = machine.facts_data[service]["generator"]["finalScript"] generator = machine.facts_data[service]["generator"]["finalScript"]
if machine.facts_data[service]["generator"]["prompt"]: if machine.facts_data[service]["generator"]["prompt"]:
prompt_value = prompt( prompt_value = prompt(
service, machine.facts_data[service]["generator"]["prompt"] service,
machine.facts_data[service]["generator"]["prompt"],
) )
env["prompt_value"] = prompt_value env["prompt_value"] = prompt_value
from clan_lib import bwrap from clan_lib import bwrap
@@ -126,7 +125,10 @@ def generate_service_facts(
msg += generator msg += generator
raise ClanError(msg) raise ClanError(msg)
secret_path = secret_facts_store.set( secret_path = secret_facts_store.set(
service, secret_name, secret_file.read_bytes(), groups service,
secret_name,
secret_file.read_bytes(),
groups,
) )
if secret_path: if secret_path:
files_to_commit.append(secret_path) files_to_commit.append(secret_path)
@@ -206,7 +208,11 @@ def generate_facts(
errors = 0 errors = 0
try: try:
was_regenerated |= _generate_facts_for_machine( was_regenerated |= _generate_facts_for_machine(
machine, service, regenerate, tmpdir, prompt machine,
service,
regenerate,
tmpdir,
prompt,
) )
except (OSError, ClanError) as e: except (OSError, ClanError) as e:
machine.error(f"Failed to generate facts: {e}") machine.error(f"Failed to generate facts: {e}")
@@ -231,7 +237,7 @@ def generate_command(args: argparse.Namespace) -> None:
filter( filter(
lambda m: m.name in args.machines, lambda m: m.name in args.machines,
machines, machines,
) ),
) )
generate_facts(machines, args.service, args.regenerate) generate_facts(machines, args.service, args.regenerate)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_generate_command_no_flake( def test_generate_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
import clan_lib.machines.machines as machines from clan_lib.machines import machines
class FactStoreBase(ABC): class FactStoreBase(ABC):

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
import clan_lib.machines.machines as machines from clan_lib.machines import machines
from clan_lib.ssh.host import Host from clan_lib.ssh.host import Host
@@ -14,7 +14,11 @@ class SecretStoreBase(ABC):
@abstractmethod @abstractmethod
def set( def set(
self, service: str, name: str, value: bytes, groups: list[str] self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None: ) -> Path | None:
pass pass

View File

@@ -16,7 +16,11 @@ class SecretStore(SecretStoreBase):
self.machine = machine self.machine = machine
def set( def set(
self, service: str, name: str, value: bytes, groups: list[str] self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None: ) -> Path | None:
subprocess.run( subprocess.run(
nix_shell( nix_shell(
@@ -40,14 +44,16 @@ class SecretStore(SecretStoreBase):
def exists(self, service: str, name: str) -> bool: def exists(self, service: str, name: str) -> bool:
password_store = os.environ.get( password_store = os.environ.get(
"PASSWORD_STORE_DIR", f"{os.environ['HOME']}/.password-store" "PASSWORD_STORE_DIR",
f"{os.environ['HOME']}/.password-store",
) )
secret_path = Path(password_store) / f"machines/{self.machine.name}/{name}.gpg" secret_path = Path(password_store) / f"machines/{self.machine.name}/{name}.gpg"
return secret_path.exists() return secret_path.exists()
def generate_hash(self) -> bytes: def generate_hash(self) -> bytes:
password_store = os.environ.get( password_store = os.environ.get(
"PASSWORD_STORE_DIR", f"{os.environ['HOME']}/.password-store" "PASSWORD_STORE_DIR",
f"{os.environ['HOME']}/.password-store",
) )
hashes = [] hashes = []
hashes.append( hashes.append(
@@ -66,7 +72,7 @@ class SecretStore(SecretStoreBase):
), ),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
check=False, check=False,
).stdout.strip() ).stdout.strip(),
) )
for symlink in Path(password_store).glob(f"machines/{self.machine.name}/**/*"): for symlink in Path(password_store).glob(f"machines/{self.machine.name}/**/*"):
if symlink.is_symlink(): if symlink.is_symlink():
@@ -86,7 +92,7 @@ class SecretStore(SecretStoreBase):
), ),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
check=False, check=False,
).stdout.strip() ).stdout.strip(),
) )
# we sort the hashes to make sure that the order is always the same # we sort the hashes to make sure that the order is always the same

View File

@@ -37,7 +37,11 @@ class SecretStore(SecretStoreBase):
add_machine(self.machine.flake_dir, self.machine.name, pub_key, False) add_machine(self.machine.flake_dir, self.machine.name, pub_key, False)
def set( def set(
self, service: str, name: str, value: bytes, groups: list[str] self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None: ) -> Path | None:
path = ( path = (
sops_secrets_folder(self.machine.flake_dir) / f"{self.machine.name}-{name}" sops_secrets_folder(self.machine.flake_dir) / f"{self.machine.name}-{name}"

View File

@@ -15,7 +15,11 @@ class SecretStore(SecretStoreBase):
self.dir.mkdir(parents=True, exist_ok=True) self.dir.mkdir(parents=True, exist_ok=True)
def set( def set(
self, service: str, name: str, value: bytes, groups: list[str] self,
service: str,
name: str,
value: bytes,
groups: list[str],
) -> Path | None: ) -> Path | None:
secret_file = self.dir / service / name secret_file = self.dir / service / name
secret_file.parent.mkdir(parents=True, exist_ok=True) secret_file.parent.mkdir(parents=True, exist_ok=True)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_upload_command_no_flake( def test_upload_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -21,6 +21,7 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
register_flash_write_parser(write_parser) register_flash_write_parser(write_parser)
list_parser = subparser.add_parser( list_parser = subparser.add_parser(
"list", help="List possible keymaps or languages" "list",
help="List possible keymaps or languages",
) )
register_flash_list_parser(list_parser) register_flash_list_parser(list_parser)

View File

@@ -121,7 +121,7 @@ def register_flash_write_parser(parser: argparse.ArgumentParser) -> None:
Format will format the disk before installing. Format will format the disk before installing.
Mount will mount the disk before installing. Mount will mount the disk before installing.
Mount is useful for updating an existing system without losing data. Mount is useful for updating an existing system without losing data.
""" """,
) )
parser.add_argument( parser.add_argument(
"--mode", "--mode",
@@ -166,7 +166,7 @@ def register_flash_write_parser(parser: argparse.ArgumentParser) -> None:
Write EFI boot entries to the NVRAM of the system for the installed system. Write EFI boot entries to the NVRAM of the system for the installed system.
Specify this option if you plan to boot from this disk on the current machine, Specify this option if you plan to boot from this disk on the current machine,
but not if you plan to move the disk to another machine. but not if you plan to move the disk to another machine.
""" """,
).strip(), ).strip(),
default=False, default=False,
action="store_true", action="store_true",

View File

@@ -8,7 +8,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core @pytest.mark.with_core
def test_flash_list_languages( def test_flash_list_languages(
temporary_home: Path, capture_output: CaptureOutput temporary_home: Path,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run(["flash", "list", "languages"]) cli.run(["flash", "list", "languages"])
@@ -20,7 +21,8 @@ def test_flash_list_languages(
@pytest.mark.with_core @pytest.mark.with_core
def test_flash_list_keymaps( def test_flash_list_keymaps(
temporary_home: Path, capture_output: CaptureOutput temporary_home: Path,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run(["flash", "list", "keymaps"]) cli.run(["flash", "list", "keymaps"])

View File

@@ -1,7 +1,6 @@
# Implementation of OSC8 # Implementation of OSC8
def hyperlink(text: str, url: str) -> str: def hyperlink(text: str, url: str) -> str:
""" """Generate OSC8 escape sequence for hyperlinks.
Generate OSC8 escape sequence for hyperlinks.
Args: Args:
url (str): The URL to link to. url (str): The URL to link to.
@@ -9,15 +8,14 @@ def hyperlink(text: str, url: str) -> str:
Returns: Returns:
str: The formatted string with an embedded hyperlink. str: The formatted string with an embedded hyperlink.
""" """
esc = "\033" esc = "\033"
return f"{esc}]8;;{url}{esc}\\{text}{esc}]8;;{esc}\\" return f"{esc}]8;;{url}{esc}\\{text}{esc}]8;;{esc}\\"
def hyperlink_same_text_and_url(url: str) -> str: def hyperlink_same_text_and_url(url: str) -> str:
""" """Keep the description and the link the same to support legacy terminals."""
Keep the description and the link the same to support legacy terminals.
"""
return hyperlink(url, url) return hyperlink(url, url)
@@ -34,9 +32,7 @@ def help_hyperlink(description: str, url: str) -> str:
def docs_hyperlink(description: str, url: str) -> str: def docs_hyperlink(description: str, url: str) -> str:
""" """Returns a markdown hyperlink"""
Returns a markdown hyperlink
"""
url = url.replace("https://docs.clan.lol", "../..") url = url.replace("https://docs.clan.lol", "../..")
url = url.replace("index.html", "index") url = url.replace("index.html", "index")
url += ".md" url += ".md"

View File

@@ -32,8 +32,7 @@ def create_machine(
opts: CreateOptions, opts: CreateOptions,
commit: bool = True, commit: bool = True,
) -> None: ) -> None:
""" """Create a new machine in the clan directory.
Create a new machine in the clan directory.
This function will create a new machine based on a template. This function will create a new machine based on a template.
@@ -41,7 +40,6 @@ def create_machine(
:param commit: Whether to commit the changes to the git repository. :param commit: Whether to commit the changes to the git repository.
:param _persist: Temporary workaround for 'morph'. Whether to persist the changes to the inventory store. :param _persist: Temporary workaround for 'morph'. Whether to persist the changes to the inventory store.
""" """
if not opts.clan_dir.is_local: if not opts.clan_dir.is_local:
msg = f"Clan {opts.clan_dir} is not a local clan." msg = f"Clan {opts.clan_dir} is not a local clan."
description = "Import machine only works on local clans" description = "Import machine only works on local clans"

View File

@@ -33,13 +33,15 @@ def update_hardware_config_command(args: argparse.Namespace) -> None:
if args.target_host: if args.target_host:
target_host = Remote.from_ssh_uri( target_host = Remote.from_ssh_uri(
machine_name=machine.name, address=args.target_host machine_name=machine.name,
address=args.target_host,
) )
else: else:
target_host = machine.target_host() target_host = machine.target_host()
target_host = target_host.override( target_host = target_host.override(
host_key_check=args.host_key_check, private_key=args.identity_file host_key_check=args.host_key_check,
private_key=args.identity_file,
) )
run_machine_hardware_info(opts, target_host) run_machine_hardware_info(opts, target_host)

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.helpers import cli
def test_create_command_no_flake( def test_create_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -34,7 +34,8 @@ def install_command(args: argparse.Namespace) -> None:
if args.target_host: if args.target_host:
# TODO add network support here with either --network or some url magic # TODO add network support here with either --network or some url magic
remote = Remote.from_ssh_uri( remote = Remote.from_ssh_uri(
machine_name=args.machine, address=args.target_host machine_name=args.machine,
address=args.target_host,
) )
elif args.png: elif args.png:
data = read_qr_image(Path(args.png)) data = read_qr_image(Path(args.png))
@@ -73,7 +74,7 @@ def install_command(args: argparse.Namespace) -> None:
if ask == "n" or ask == "": if ask == "n" or ask == "":
return None return None
print( print(
f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no." f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no.",
) )
if args.identity_file: if args.identity_file:

View File

@@ -13,7 +13,8 @@ def list_command(args: argparse.Namespace) -> None:
flake = require_flake(args.flake) flake = require_flake(args.flake)
for name in list_machines( for name in list_machines(
flake, opts=ListOptions(filter=MachineFilter(tags=args.tags)) flake,
opts=ListOptions(filter=MachineFilter(tags=args.tags)),
): ):
print(name) print(name)

View File

@@ -43,7 +43,7 @@ def list_basic(
description = "Backup server"; description = "Backup server";
}; };
}; };
}""" }""",
}, },
], ],
indirect=True, indirect=True,
@@ -62,7 +62,7 @@ def list_with_tags_single_tag(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"production", "production",
] ],
) )
assert "web-server" in output.out assert "web-server" in output.out
@@ -94,7 +94,7 @@ def list_with_tags_single_tag(
description = "Backup server"; description = "Backup server";
}; };
}; };
}""" }""",
}, },
], ],
indirect=True, indirect=True,
@@ -114,7 +114,7 @@ def list_with_tags_multiple_tags_intersection(
"--tags", "--tags",
"web", "web",
"production", "production",
] ],
) )
# Should only include machines that have BOTH tags (intersection) # Should only include machines that have BOTH tags (intersection)
@@ -139,7 +139,7 @@ def test_machines_list_with_tags_no_matches(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"nonexistent", "nonexistent",
] ],
) )
assert output.out.strip() == "" assert output.out.strip() == ""
@@ -162,7 +162,7 @@ def test_machines_list_with_tags_no_matches(
}; };
server4 = { }; server4 = { };
}; };
}""" }""",
}, },
], ],
indirect=True, indirect=True,
@@ -180,7 +180,7 @@ def list_with_tags_various_scenarios(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"web", "web",
] ],
) )
assert "server1" in output.out assert "server1" in output.out
@@ -197,7 +197,7 @@ def list_with_tags_various_scenarios(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"database", "database",
] ],
) )
assert "server2" in output.out assert "server2" in output.out
@@ -216,7 +216,7 @@ def list_with_tags_various_scenarios(
"--tags", "--tags",
"web", "web",
"database", "database",
] ],
) )
assert "server3" in output.out assert "server3" in output.out
@@ -239,7 +239,7 @@ def created_machine_and_tags(
"--tags", "--tags",
"test", "test",
"server", "server",
] ],
) )
with capture_output as output: with capture_output as output:
@@ -258,7 +258,7 @@ def created_machine_and_tags(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"test", "test",
] ],
) )
assert "test-machine" in output.out assert "test-machine" in output.out
@@ -274,7 +274,7 @@ def created_machine_and_tags(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"server", "server",
] ],
) )
assert "test-machine" in output.out assert "test-machine" in output.out
@@ -291,7 +291,7 @@ def created_machine_and_tags(
"--tags", "--tags",
"test", "test",
"server", "server",
] ],
) )
assert "test-machine" in output.out assert "test-machine" in output.out
@@ -310,7 +310,7 @@ def created_machine_and_tags(
}; };
machine-without-tags = { }; machine-without-tags = { };
}; };
}""" }""",
}, },
], ],
indirect=True, indirect=True,
@@ -334,7 +334,7 @@ def list_mixed_tagged_untagged(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"tag1", "tag1",
] ],
) )
assert "machine-with-tags" in output.out assert "machine-with-tags" in output.out
@@ -349,7 +349,7 @@ def list_mixed_tagged_untagged(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--tags", "--tags",
"nonexistent", "nonexistent",
] ],
) )
assert "machine-with-tags" not in output.out assert "machine-with-tags" not in output.out
@@ -358,7 +358,8 @@ def list_mixed_tagged_untagged(
def test_machines_list_require_flake_error( def test_machines_list_require_flake_error(
temporary_home: Path, monkeypatch: pytest.MonkeyPatch temporary_home: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
"""Test that machines list command fails when flake is required but not provided.""" """Test that machines list command fails when flake is required but not provided."""
monkeypatch.chdir(temporary_home) monkeypatch.chdir(temporary_home)

View File

@@ -15,7 +15,7 @@ from clan_cli.tests.fixtures_flakes import FlakeForTest
machines.jon1 = { }; machines.jon1 = { };
machines.jon2 = { machineClass = "nixos"; }; machines.jon2 = { machineClass = "nixos"; };
machines.sara = { machineClass = "darwin"; }; machines.sara = { machineClass = "darwin"; };
}""" }""",
}, },
], ],
# Important! # Important!
@@ -27,8 +27,7 @@ from clan_cli.tests.fixtures_flakes import FlakeForTest
def test_inventory_machine_detect_class( def test_inventory_machine_detect_class(
test_flake_with_core: FlakeForTest, test_flake_with_core: FlakeForTest,
) -> None: ) -> None:
""" """Testing different inventory deserializations
Testing different inventory deserializations
Inventory should always be deserializable to a dict Inventory should always be deserializable to a dict
""" """
machine_jon1 = Machine( machine_jon1 = Machine(

View File

@@ -87,7 +87,8 @@ def get_machines_for_update(
) -> list[Machine]: ) -> list[Machine]:
all_machines = list_machines(flake) all_machines = list_machines(flake)
machines_with_tags = list_machines( machines_with_tags = list_machines(
flake, ListOptions(filter=MachineFilter(tags=filter_tags)) flake,
ListOptions(filter=MachineFilter(tags=filter_tags)),
) )
if filter_tags and not machines_with_tags: if filter_tags and not machines_with_tags:
@@ -101,7 +102,7 @@ def get_machines_for_update(
filter( filter(
requires_explicit_update, requires_explicit_update,
instantiate_inventory_to_machines(flake, machines_with_tags).values(), instantiate_inventory_to_machines(flake, machines_with_tags).values(),
) ),
) )
# all machines that are in the clan but not included in the update list # all machines that are in the clan but not included in the update list
machine_names_to_update = [m.name for m in machines_to_update] machine_names_to_update = [m.name for m in machines_to_update]
@@ -131,7 +132,7 @@ def get_machines_for_update(
raise ClanError(msg) raise ClanError(msg)
machines_to_update.append( machines_to_update.append(
Machine.from_inventory(name, flake, inventory_machine) Machine.from_inventory(name, flake, inventory_machine),
) )
return machines_to_update return machines_to_update
@@ -163,7 +164,7 @@ def update_command(args: argparse.Namespace) -> None:
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.settings.secretModule", f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.settings.secretModule",
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.deployment.requireExplicitUpdate", f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.deployment.requireExplicitUpdate",
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.system.clan.deployment.nixosMobileWorkaround", f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.system.clan.deployment.nixosMobileWorkaround",
] ],
) )
host_key_check = args.host_key_check host_key_check = args.host_key_check

View File

@@ -17,12 +17,12 @@ from clan_cli.tests.helpers import cli
"inventory_expr": r"""{ "inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; }; machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; }; machines.sara = { tags = [ "foo" "baz" ]; };
}""" }""",
}, },
["jon"], # explizit names ["jon"], # explizit names
[], # filter tags [], # filter tags
["jon"], # expected ["jon"], # expected
) ),
], ],
# Important! # Important!
# tells pytest to pass these values to the fixture # tells pytest to pass these values to the fixture
@@ -55,12 +55,12 @@ def test_get_machines_for_update_single_name(
"inventory_expr": r"""{ "inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; }; machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; }; machines.sara = { tags = [ "foo" "baz" ]; };
}""" }""",
}, },
[], # explizit names [], # explizit names
["foo"], # filter tags ["foo"], # filter tags
["jon", "sara"], # expected ["jon", "sara"], # expected
) ),
], ],
# Important! # Important!
# tells pytest to pass these values to the fixture # tells pytest to pass these values to the fixture
@@ -93,12 +93,12 @@ def test_get_machines_for_update_tags(
"inventory_expr": r"""{ "inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; }; machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; }; machines.sara = { tags = [ "foo" "baz" ]; };
}""" }""",
}, },
["sara"], # explizit names ["sara"], # explizit names
["foo"], # filter tags ["foo"], # filter tags
["sara"], # expected ["sara"], # expected
) ),
], ],
# Important! # Important!
# tells pytest to pass these values to the fixture # tells pytest to pass these values to the fixture
@@ -131,7 +131,7 @@ def test_get_machines_for_update_tags_and_name(
"inventory_expr": r"""{ "inventory_expr": r"""{
machines.jon = { tags = [ "foo" "bar" ]; }; machines.jon = { tags = [ "foo" "bar" ]; };
machines.sara = { tags = [ "foo" "baz" ]; }; machines.sara = { tags = [ "foo" "baz" ]; };
}""" }""",
}, },
[], # no explizit names [], # no explizit names
[], # no filter tags [], # no filter tags
@@ -162,7 +162,8 @@ def test_get_machines_for_update_implicit_all(
def test_update_command_no_flake( def test_update_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -19,7 +19,7 @@ def list_command(args: argparse.Namespace) -> None:
col_network = max(12, max(len(name) for name in networks)) col_network = max(12, max(len(name) for name in networks))
col_priority = 8 col_priority = 8
col_module = max( col_module = max(
10, max(len(net.module_name.split(".")[-1]) for net in networks.values()) 10, max(len(net.module_name.split(".")[-1]) for net in networks.values()),
) )
col_running = 8 col_running = 8
@@ -30,7 +30,8 @@ def list_command(args: argparse.Namespace) -> None:
# Print network entries # Print network entries
for network_name, network in sorted( for network_name, network in sorted(
networks.items(), key=lambda network: -network[1].priority networks.items(),
key=lambda network: -network[1].priority,
): ):
# Extract simple module name from full module path # Extract simple module name from full module path
module_name = network.module_name.split(".")[-1] module_name = network.module_name.split(".")[-1]
@@ -56,7 +57,7 @@ def list_command(args: argparse.Namespace) -> None:
running_status = "Error" running_status = "Error"
print( print(
f"{network_name:<{col_network}} {network.priority:<{col_priority}} {module_name:<{col_module}} {running_status:<{col_running}} {peers_str}" f"{network_name:<{col_network}} {network.priority:<{col_priority}} {module_name:<{col_module}} {running_status:<{col_running}} {peers_str}",
) )

View File

@@ -95,8 +95,7 @@ PROFS = ProfilerStore()
def profile(func: Callable) -> Callable: def profile(func: Callable) -> Callable:
""" """A decorator that profiles the decorated function, printing out the profiling
A decorator that profiles the decorated function, printing out the profiling
results with paths trimmed to three directories deep. results with paths trimmed to three directories deep.
""" """

View File

@@ -39,7 +39,8 @@ class QgaSession:
def run_nonblocking(self, cmd: list[str]) -> int: def run_nonblocking(self, cmd: list[str]) -> int:
result_pid = self.client.cmd( result_pid = self.client.cmd(
"guest-exec", {"path": cmd[0], "arg": cmd[1:], "capture-output": True} "guest-exec",
{"path": cmd[0], "arg": cmd[1:], "capture-output": True},
) )
if result_pid is None: if result_pid is None:
msg = "Could not get PID from QGA" msg = "Could not get PID from QGA"

View File

@@ -20,32 +20,23 @@ from clan_lib.errors import ClanError
class QMPError(Exception): class QMPError(Exception):
""" """QMP base exception"""
QMP base exception
"""
class QMPConnectError(QMPError): class QMPConnectError(QMPError):
""" """QMP connection exception"""
QMP connection exception
"""
class QMPCapabilitiesError(QMPError): class QMPCapabilitiesError(QMPError):
""" """QMP negotiate capabilities exception"""
QMP negotiate capabilities exception
"""
class QMPTimeoutError(QMPError): class QMPTimeoutError(QMPError):
""" """QMP timeout exception"""
QMP timeout exception
"""
class QEMUMonitorProtocol: class QEMUMonitorProtocol:
""" """Provide an API to connect to QEMU via QEMU Monitor Protocol (QMP) and then
Provide an API to connect to QEMU via QEMU Monitor Protocol (QMP) and then
allow to handle commands and events. allow to handle commands and events.
""" """
@@ -58,8 +49,7 @@ class QEMUMonitorProtocol:
server: bool = False, server: bool = False,
nickname: str | None = None, nickname: str | None = None,
) -> None: ) -> None:
""" """Create a QEMUMonitorProtocol class.
Create a QEMUMonitorProtocol class.
@param address: QEMU address, can be either a unix socket path (string) @param address: QEMU address, can be either a unix socket path (string)
or a tuple in the form ( address, port ) for a TCP or a tuple in the form ( address, port ) for a TCP
@@ -109,8 +99,7 @@ class QEMUMonitorProtocol:
return resp return resp
def __get_events(self, wait: bool | float = False) -> None: def __get_events(self, wait: bool | float = False) -> None:
""" """Check for new events in the stream and cache them in __events.
Check for new events in the stream and cache them in __events.
@param wait (bool): block until an event is available. @param wait (bool): block until an event is available.
@param wait (float): If wait is a float, treat it as a timeout value. @param wait (float): If wait is a float, treat it as a timeout value.
@@ -120,7 +109,6 @@ class QEMUMonitorProtocol:
@raise QMPConnectError: If wait is True but no events could be @raise QMPConnectError: If wait is True but no events could be
retrieved or if some other error occurred. retrieved or if some other error occurred.
""" """
# Check for new events regardless and pull them into the cache: # Check for new events regardless and pull them into the cache:
self.__sock.setblocking(0) self.__sock.setblocking(0)
try: try:
@@ -163,8 +151,7 @@ class QEMUMonitorProtocol:
self.close() self.close()
def connect(self, negotiate: bool = True) -> dict[str, Any] | None: def connect(self, negotiate: bool = True) -> dict[str, Any] | None:
""" """Connect to the QMP Monitor and perform capabilities negotiation.
Connect to the QMP Monitor and perform capabilities negotiation.
@return QMP greeting dict, or None if negotiate is false @return QMP greeting dict, or None if negotiate is false
@raise OSError on socket connection errors @raise OSError on socket connection errors
@@ -178,8 +165,7 @@ class QEMUMonitorProtocol:
return None return None
def accept(self, timeout: float | None = 15.0) -> dict[str, Any]: def accept(self, timeout: float | None = 15.0) -> dict[str, Any]:
""" """Await connection from QMP Monitor and perform capabilities negotiation.
Await connection from QMP Monitor and perform capabilities negotiation.
@param timeout: timeout in seconds (nonnegative float number, or @param timeout: timeout in seconds (nonnegative float number, or
None). The value passed will set the behavior of the None). The value passed will set the behavior of the
@@ -199,8 +185,7 @@ class QEMUMonitorProtocol:
return self.__negotiate_capabilities() return self.__negotiate_capabilities()
def cmd_obj(self, qmp_cmd: dict[str, Any]) -> dict[str, Any] | None: def cmd_obj(self, qmp_cmd: dict[str, Any]) -> dict[str, Any] | None:
""" """Send a QMP command to the QMP Monitor.
Send a QMP command to the QMP Monitor.
@param qmp_cmd: QMP command to be sent as a Python dict @param qmp_cmd: QMP command to be sent as a Python dict
@return QMP response as a Python dict or None if the connection has @return QMP response as a Python dict or None if the connection has
@@ -223,8 +208,7 @@ class QEMUMonitorProtocol:
args: dict[str, Any] | None = None, args: dict[str, Any] | None = None,
cmd_id: dict[str, Any] | list[Any] | str | int | None = None, cmd_id: dict[str, Any] | list[Any] | str | int | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
""" """Build a QMP command and send it to the QMP Monitor.
Build a QMP command and send it to the QMP Monitor.
@param name: command name (string) @param name: command name (string)
@param args: command arguments (dict) @param args: command arguments (dict)
@@ -238,17 +222,14 @@ class QEMUMonitorProtocol:
return self.cmd_obj(qmp_cmd) return self.cmd_obj(qmp_cmd)
def command(self, cmd: str, **kwds: Any) -> Any: def command(self, cmd: str, **kwds: Any) -> Any:
""" """Build and send a QMP command to the monitor, report errors if any"""
Build and send a QMP command to the monitor, report errors if any
"""
ret = self.cmd(cmd, kwds) ret = self.cmd(cmd, kwds)
if "error" in ret: if "error" in ret:
raise ClanError(ret["error"]["desc"]) raise ClanError(ret["error"]["desc"])
return ret["return"] return ret["return"]
def pull_event(self, wait: bool | float = False) -> dict[str, Any] | None: def pull_event(self, wait: bool | float = False) -> dict[str, Any] | None:
""" """Pulls a single event.
Pulls a single event.
@param wait (bool): block until an event is available. @param wait (bool): block until an event is available.
@param wait (float): If wait is a float, treat it as a timeout value. @param wait (float): If wait is a float, treat it as a timeout value.
@@ -267,8 +248,7 @@ class QEMUMonitorProtocol:
return None return None
def get_events(self, wait: bool | float = False) -> list[dict[str, Any]]: def get_events(self, wait: bool | float = False) -> list[dict[str, Any]]:
""" """Get a list of available QMP events.
Get a list of available QMP events.
@param wait (bool): block until an event is available. @param wait (bool): block until an event is available.
@param wait (float): If wait is a float, treat it as a timeout value. @param wait (float): If wait is a float, treat it as a timeout value.
@@ -284,23 +264,18 @@ class QEMUMonitorProtocol:
return self.__events return self.__events
def clear_events(self) -> None: def clear_events(self) -> None:
""" """Clear current list of pending events."""
Clear current list of pending events.
"""
self.__events = [] self.__events = []
def close(self) -> None: def close(self) -> None:
""" """Close the socket and socket file."""
Close the socket and socket file.
"""
if self.__sock: if self.__sock:
self.__sock.close() self.__sock.close()
if self.__sockfile: if self.__sockfile:
self.__sockfile.close() self.__sockfile.close()
def settimeout(self, timeout: float | None) -> None: def settimeout(self, timeout: float | None) -> None:
""" """Set the socket timeout.
Set the socket timeout.
@param timeout (float): timeout in seconds, or None. @param timeout (float): timeout in seconds, or None.
@note This is a wrap around socket.settimeout @note This is a wrap around socket.settimeout
@@ -308,16 +283,14 @@ class QEMUMonitorProtocol:
self.__sock.settimeout(timeout) self.__sock.settimeout(timeout)
def get_sock_fd(self) -> int: def get_sock_fd(self) -> int:
""" """Get the socket file descriptor.
Get the socket file descriptor.
@return The file descriptor number. @return The file descriptor number.
""" """
return self.__sock.fileno() return self.__sock.fileno()
def is_scm_available(self) -> bool: def is_scm_available(self) -> bool:
""" """Check if the socket allows for SCM_RIGHTS.
Check if the socket allows for SCM_RIGHTS.
@return True if SCM_RIGHTS is available, otherwise False. @return True if SCM_RIGHTS is available, otherwise False.
""" """

View File

@@ -41,7 +41,11 @@ def users_folder(flake_dir: Path, group: str) -> Path:
class Group: class Group:
def __init__( def __init__(
self, flake_dir: Path, name: str, machines: list[str], users: list[str] self,
flake_dir: Path,
name: str,
machines: list[str],
users: list[str],
) -> None: ) -> None:
self.name = name self.name = name
self.machines = machines self.machines = machines
@@ -235,13 +239,18 @@ def remove_machine_command(args: argparse.Namespace) -> None:
def add_group_argument(parser: argparse.ArgumentParser) -> None: def add_group_argument(parser: argparse.ArgumentParser) -> None:
group_action = parser.add_argument( group_action = parser.add_argument(
"group", help="the name of the secret", type=group_name_type "group",
help="the name of the secret",
type=group_name_type,
) )
add_dynamic_completer(group_action, complete_groups) add_dynamic_completer(group_action, complete_groups)
def add_secret( def add_secret(
flake_dir: Path, group: str, name: str, age_plugins: list[str] | None flake_dir: Path,
group: str,
name: str,
age_plugins: list[str] | None,
) -> None: ) -> None:
secrets.allow_member( secrets.allow_member(
secrets.groups_folder(sops_secrets_folder(flake_dir) / name), secrets.groups_folder(sops_secrets_folder(flake_dir) / name),
@@ -276,7 +285,10 @@ def add_secret_command(args: argparse.Namespace) -> None:
def remove_secret( def remove_secret(
flake_dir: Path, group: str, name: str, age_plugins: list[str] flake_dir: Path,
group: str,
name: str,
age_plugins: list[str],
) -> None: ) -> None:
updated_paths = secrets.disallow_member( updated_paths = secrets.disallow_member(
secrets.groups_folder(sops_secrets_folder(flake_dir) / name), secrets.groups_folder(sops_secrets_folder(flake_dir) / name),
@@ -313,22 +325,28 @@ def register_groups_parser(parser: argparse.ArgumentParser) -> None:
# Add user # Add user
add_machine_parser = subparser.add_parser( add_machine_parser = subparser.add_parser(
"add-machine", help="add a machine to group" "add-machine",
help="add a machine to group",
) )
add_group_argument(add_machine_parser) add_group_argument(add_machine_parser)
add_machine_action = add_machine_parser.add_argument( add_machine_action = add_machine_parser.add_argument(
"machine", help="the name of the machines to add", type=machine_name_type "machine",
help="the name of the machines to add",
type=machine_name_type,
) )
add_dynamic_completer(add_machine_action, complete_machines) add_dynamic_completer(add_machine_action, complete_machines)
add_machine_parser.set_defaults(func=add_machine_command) add_machine_parser.set_defaults(func=add_machine_command)
# Remove machine # Remove machine
remove_machine_parser = subparser.add_parser( remove_machine_parser = subparser.add_parser(
"remove-machine", help="remove a machine from group" "remove-machine",
help="remove a machine from group",
) )
add_group_argument(remove_machine_parser) add_group_argument(remove_machine_parser)
remove_machine_action = remove_machine_parser.add_argument( remove_machine_action = remove_machine_parser.add_argument(
"machine", help="the name of the machines to remove", type=machine_name_type "machine",
help="the name of the machines to remove",
type=machine_name_type,
) )
add_dynamic_completer(remove_machine_action, complete_machines) add_dynamic_completer(remove_machine_action, complete_machines)
remove_machine_parser.set_defaults(func=remove_machine_command) remove_machine_parser.set_defaults(func=remove_machine_command)
@@ -337,40 +355,51 @@ def register_groups_parser(parser: argparse.ArgumentParser) -> None:
add_user_parser = subparser.add_parser("add-user", help="add a user to group") add_user_parser = subparser.add_parser("add-user", help="add a user to group")
add_group_argument(add_user_parser) add_group_argument(add_user_parser)
add_user_action = add_user_parser.add_argument( add_user_action = add_user_parser.add_argument(
"user", help="the name of the user to add", type=user_name_type "user",
help="the name of the user to add",
type=user_name_type,
) )
add_dynamic_completer(add_user_action, complete_users) add_dynamic_completer(add_user_action, complete_users)
add_user_parser.set_defaults(func=add_user_command) add_user_parser.set_defaults(func=add_user_command)
# Remove user # Remove user
remove_user_parser = subparser.add_parser( remove_user_parser = subparser.add_parser(
"remove-user", help="remove a user from a group" "remove-user",
help="remove a user from a group",
) )
add_group_argument(remove_user_parser) add_group_argument(remove_user_parser)
remove_user_action = remove_user_parser.add_argument( remove_user_action = remove_user_parser.add_argument(
"user", help="the name of the user to remove", type=user_name_type "user",
help="the name of the user to remove",
type=user_name_type,
) )
add_dynamic_completer(remove_user_action, complete_users) add_dynamic_completer(remove_user_action, complete_users)
remove_user_parser.set_defaults(func=remove_user_command) remove_user_parser.set_defaults(func=remove_user_command)
# Add secret # Add secret
add_secret_parser = subparser.add_parser( add_secret_parser = subparser.add_parser(
"add-secret", help="allow a groups to access a secret" "add-secret",
help="allow a groups to access a secret",
) )
add_group_argument(add_secret_parser) add_group_argument(add_secret_parser)
add_secret_action = add_secret_parser.add_argument( add_secret_action = add_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type "secret",
help="the name of the secret",
type=secret_name_type,
) )
add_dynamic_completer(add_secret_action, complete_secrets) add_dynamic_completer(add_secret_action, complete_secrets)
add_secret_parser.set_defaults(func=add_secret_command) add_secret_parser.set_defaults(func=add_secret_command)
# Remove secret # Remove secret
remove_secret_parser = subparser.add_parser( remove_secret_parser = subparser.add_parser(
"remove-secret", help="remove a group's access to a secret" "remove-secret",
help="remove a group's access to a secret",
) )
add_group_argument(remove_secret_parser) add_group_argument(remove_secret_parser)
remove_secret_action = remove_secret_parser.add_argument( remove_secret_action = remove_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type "secret",
help="the name of the secret",
type=secret_name_type,
) )
add_dynamic_completer(remove_secret_action, complete_secrets) add_dynamic_completer(remove_secret_action, complete_secrets)
remove_secret_parser.set_defaults(func=remove_secret_command) remove_secret_parser.set_defaults(func=remove_secret_command)

View File

@@ -19,8 +19,7 @@ log = logging.getLogger(__name__)
def generate_key() -> sops.SopsKey: def generate_key() -> sops.SopsKey:
""" """Generate a new age key and return it as a SopsKey.
Generate a new age key and return it as a SopsKey.
This function does not check if the key already exists. This function does not check if the key already exists.
It will generate a new key every time it is called. It will generate a new key every time it is called.
@@ -28,14 +27,16 @@ def generate_key() -> sops.SopsKey:
Use 'check_key_exists' to check if a key already exists. Use 'check_key_exists' to check if a key already exists.
Before calling this function if you dont want to generate a new key. Before calling this function if you dont want to generate a new key.
""" """
path = default_admin_private_key_path() path = default_admin_private_key_path()
_, pub_key = generate_private_key(out_file=path) _, pub_key = generate_private_key(out_file=path)
log.info( log.info(
f"Generated age private key at '{path}' for your user.\nPlease back it up on a secure location or you will lose access to your secrets." f"Generated age private key at '{path}' for your user.\nPlease back it up on a secure location or you will lose access to your secrets.",
) )
return sops.SopsKey( return sops.SopsKey(
pub_key, username="", key_type=sops.KeyType.AGE, source=str(path) pub_key,
username="",
key_type=sops.KeyType.AGE,
source=str(path),
) )
@@ -49,7 +50,8 @@ def generate_command(args: argparse.Namespace) -> None:
key_type = key.key_type.name.lower() key_type = key.key_type.name.lower()
print(f"{key.key_type.name} key {key.pubkey} is already set", file=sys.stderr) print(f"{key.key_type.name} key {key.pubkey} is already set", file=sys.stderr)
print( print(
f"Add your {key_type} public key to the repository with:", file=sys.stderr f"Add your {key_type} public key to the repository with:",
file=sys.stderr,
) )
print( print(
f"clan secrets users add <username> --{key_type}-key {key.pubkey}", f"clan secrets users add <username> --{key_type}-key {key.pubkey}",

View File

@@ -59,16 +59,12 @@ def get_machine_pubkey(flake_dir: Path, name: str) -> str:
def has_machine(flake_dir: Path, name: str) -> bool: def has_machine(flake_dir: Path, name: str) -> bool:
""" """Checks if a machine exists in the sops machines folder"""
Checks if a machine exists in the sops machines folder
"""
return (sops_machines_folder(flake_dir) / name / "key.json").exists() return (sops_machines_folder(flake_dir) / name / "key.json").exists()
def list_sops_machines(flake_dir: Path) -> list[str]: def list_sops_machines(flake_dir: Path) -> list[str]:
""" """Lists all machines in the sops machines folder"""
Lists all machines in the sops machines folder
"""
path = sops_machines_folder(flake_dir) path = sops_machines_folder(flake_dir)
def validate(name: str) -> bool: def validate(name: str) -> bool:
@@ -97,7 +93,10 @@ def add_secret(
def remove_secret( def remove_secret(
flake_dir: Path, machine: str, secret: str, age_plugins: list[str] | None flake_dir: Path,
machine: str,
secret: str,
age_plugins: list[str] | None,
) -> None: ) -> None:
updated_paths = secrets.disallow_member( updated_paths = secrets.disallow_member(
secrets.machines_folder(sops_secrets_folder(flake_dir) / secret), secrets.machines_folder(sops_secrets_folder(flake_dir) / secret),
@@ -174,7 +173,9 @@ def register_machines_parser(parser: argparse.ArgumentParser) -> None:
default=False, default=False,
) )
add_machine_action = add_parser.add_argument( add_machine_action = add_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type "machine",
help="the name of the machine",
type=machine_name_type,
) )
add_dynamic_completer(add_machine_action, complete_machines) add_dynamic_completer(add_machine_action, complete_machines)
add_parser.add_argument( add_parser.add_argument(
@@ -187,7 +188,9 @@ def register_machines_parser(parser: argparse.ArgumentParser) -> None:
# Parser # Parser
get_parser = subparser.add_parser("get", help="get a machine public key") get_parser = subparser.add_parser("get", help="get a machine public key")
get_machine_parser = get_parser.add_argument( get_machine_parser = get_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type "machine",
help="the name of the machine",
type=machine_name_type,
) )
add_dynamic_completer(get_machine_parser, complete_machines) add_dynamic_completer(get_machine_parser, complete_machines)
get_parser.set_defaults(func=get_command) get_parser.set_defaults(func=get_command)
@@ -195,35 +198,47 @@ def register_machines_parser(parser: argparse.ArgumentParser) -> None:
# Parser # Parser
remove_parser = subparser.add_parser("remove", help="remove a machine") remove_parser = subparser.add_parser("remove", help="remove a machine")
remove_machine_parser = remove_parser.add_argument( remove_machine_parser = remove_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type "machine",
help="the name of the machine",
type=machine_name_type,
) )
add_dynamic_completer(remove_machine_parser, complete_machines) add_dynamic_completer(remove_machine_parser, complete_machines)
remove_parser.set_defaults(func=remove_command) remove_parser.set_defaults(func=remove_command)
# Parser # Parser
add_secret_parser = subparser.add_parser( add_secret_parser = subparser.add_parser(
"add-secret", help="allow a machine to access a secret" "add-secret",
help="allow a machine to access a secret",
) )
machine_add_secret_parser = add_secret_parser.add_argument( machine_add_secret_parser = add_secret_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type "machine",
help="the name of the machine",
type=machine_name_type,
) )
add_dynamic_completer(machine_add_secret_parser, complete_machines) add_dynamic_completer(machine_add_secret_parser, complete_machines)
add_secret_action = add_secret_parser.add_argument( add_secret_action = add_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type "secret",
help="the name of the secret",
type=secret_name_type,
) )
add_dynamic_completer(add_secret_action, complete_secrets) add_dynamic_completer(add_secret_action, complete_secrets)
add_secret_parser.set_defaults(func=add_secret_command) add_secret_parser.set_defaults(func=add_secret_command)
# Parser # Parser
remove_secret_parser = subparser.add_parser( remove_secret_parser = subparser.add_parser(
"remove-secret", help="remove a group's access to a secret" "remove-secret",
help="remove a group's access to a secret",
) )
machine_remove_parser = remove_secret_parser.add_argument( machine_remove_parser = remove_secret_parser.add_argument(
"machine", help="the name of the machine", type=machine_name_type "machine",
help="the name of the machine",
type=machine_name_type,
) )
add_dynamic_completer(machine_remove_parser, complete_machines) add_dynamic_completer(machine_remove_parser, complete_machines)
remove_secret_action = remove_secret_parser.add_argument( remove_secret_action = remove_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type "secret",
help="the name of the secret",
type=secret_name_type,
) )
add_dynamic_completer(remove_secret_action, complete_secrets) add_dynamic_completer(remove_secret_action, complete_secrets)
remove_secret_parser.set_defaults(func=remove_secret_command) remove_secret_parser.set_defaults(func=remove_secret_command)

View File

@@ -50,7 +50,8 @@ def list_generators_secrets(generators_path: Path) -> list[Path]:
return has_secret(generator_path / name) return has_secret(generator_path / name)
for obj in list_objects( for obj in list_objects(
generator_path, functools.partial(validate, generator_path) generator_path,
functools.partial(validate, generator_path),
): ):
paths.append(generator_path / obj) paths.append(generator_path / obj)
return paths return paths
@@ -89,7 +90,7 @@ def update_secrets(
changed_files.extend(cleanup_dangling_symlinks(path / "groups")) changed_files.extend(cleanup_dangling_symlinks(path / "groups"))
changed_files.extend(cleanup_dangling_symlinks(path / "machines")) changed_files.extend(cleanup_dangling_symlinks(path / "machines"))
changed_files.extend( changed_files.extend(
update_keys(path, collect_keys_for_path(path), age_plugins=age_plugins) update_keys(path, collect_keys_for_path(path), age_plugins=age_plugins),
) )
return changed_files return changed_files
@@ -120,7 +121,7 @@ def collect_keys_for_type(folder: Path) -> set[sops.SopsKey]:
kind = target.parent.name kind = target.parent.name
if folder.name != kind: if folder.name != kind:
log.warning( log.warning(
f"Expected {p} to point to {folder} but points to {target.parent}" f"Expected {p} to point to {folder} but points to {target.parent}",
) )
continue continue
keys.update(read_keys(target)) keys.update(read_keys(target))
@@ -160,7 +161,7 @@ def encrypt_secret(
admin_keys = sops.ensure_admin_public_keys(flake_dir) admin_keys = sops.ensure_admin_public_keys(flake_dir)
if not admin_keys: if not admin_keys:
# todo double check the correct command to run # TODO double check the correct command to run
msg = "No keys found. Please run 'clan secrets add-key' to add a key." msg = "No keys found. Please run 'clan secrets add-key' to add a key."
raise ClanError(msg) raise ClanError(msg)
@@ -179,7 +180,7 @@ def encrypt_secret(
user, user,
do_update_keys, do_update_keys,
age_plugins=age_plugins, age_plugins=age_plugins,
) ),
) )
for machine in add_machines: for machine in add_machines:
@@ -190,7 +191,7 @@ def encrypt_secret(
machine, machine,
do_update_keys, do_update_keys,
age_plugins=age_plugins, age_plugins=age_plugins,
) ),
) )
for group in add_groups: for group in add_groups:
@@ -201,7 +202,7 @@ def encrypt_secret(
group, group,
do_update_keys, do_update_keys,
age_plugins=age_plugins, age_plugins=age_plugins,
) ),
) )
recipient_keys = collect_keys_for_path(secret_path) recipient_keys = collect_keys_for_path(secret_path)
@@ -216,7 +217,7 @@ def encrypt_secret(
username, username,
do_update_keys, do_update_keys,
age_plugins=age_plugins, age_plugins=age_plugins,
) ),
) )
secret_path = secret_path / "secret" secret_path = secret_path / "secret"
@@ -310,13 +311,15 @@ def allow_member(
group_folder.parent, group_folder.parent,
collect_keys_for_path(group_folder.parent), collect_keys_for_path(group_folder.parent),
age_plugins=age_plugins, age_plugins=age_plugins,
) ),
) )
return changed return changed
def disallow_member( def disallow_member(
group_folder: Path, name: str, age_plugins: list[str] | None group_folder: Path,
name: str,
age_plugins: list[str] | None,
) -> list[Path]: ) -> list[Path]:
target = group_folder / name target = group_folder / name
if not target.exists(): if not target.exists():
@@ -349,7 +352,8 @@ def has_secret(secret_path: Path) -> bool:
def list_secrets( def list_secrets(
flake_dir: Path, filter_fn: Callable[[str], bool] | None = None flake_dir: Path,
filter_fn: Callable[[str], bool] | None = None,
) -> list[str]: ) -> list[str]:
path = sops_secrets_folder(flake_dir) path = sops_secrets_folder(flake_dir)

View File

@@ -66,7 +66,7 @@ class KeyType(enum.Enum):
for public_key in get_public_age_keys(content): for public_key in get_public_age_keys(content):
log.debug( log.debug(
f"Found age public key from a private key " f"Found age public key from a private key "
f"in {key_path}: {public_key}" f"in {key_path}: {public_key}",
) )
keyring.append( keyring.append(
@@ -75,7 +75,7 @@ class KeyType(enum.Enum):
username="", username="",
key_type=self, key_type=self,
source=str(key_path), source=str(key_path),
) ),
) )
except ClanError as e: except ClanError as e:
error_msg = f"Failed to read age keys from {key_path}" error_msg = f"Failed to read age keys from {key_path}"
@@ -96,7 +96,7 @@ class KeyType(enum.Enum):
for public_key in get_public_age_keys(content): for public_key in get_public_age_keys(content):
log.debug( log.debug(
f"Found age public key from a private key " f"Found age public key from a private key "
f"in the environment (SOPS_AGE_KEY): {public_key}" f"in the environment (SOPS_AGE_KEY): {public_key}",
) )
keyring.append( keyring.append(
@@ -105,7 +105,7 @@ class KeyType(enum.Enum):
username="", username="",
key_type=self, key_type=self,
source="SOPS_AGE_KEY", source="SOPS_AGE_KEY",
) ),
) )
except ClanError as e: except ClanError as e:
error_msg = "Failed to read age keys from SOPS_AGE_KEY" error_msg = "Failed to read age keys from SOPS_AGE_KEY"
@@ -126,8 +126,11 @@ class KeyType(enum.Enum):
log.debug(msg) log.debug(msg)
keyring.append( keyring.append(
SopsKey( SopsKey(
pubkey=fp, username="", key_type=self, source="SOPS_PGP_FP" pubkey=fp,
) username="",
key_type=self,
source="SOPS_PGP_FP",
),
) )
return keyring return keyring
@@ -389,7 +392,7 @@ def get_user_name(flake_dir: Path, user: str) -> str:
"""Ask the user for their name until a unique one is provided.""" """Ask the user for their name until a unique one is provided."""
while True: while True:
name = input( name = input(
f"Your key is not yet added to the repository. Enter your user name for which your sops key will be stored in the repository [default: {user}]: " f"Your key is not yet added to the repository. Enter your user name for which your sops key will be stored in the repository [default: {user}]: ",
) )
if name: if name:
user = name user = name
@@ -455,7 +458,9 @@ def ensure_admin_public_keys(flake_dir: Path) -> set[SopsKey]:
def update_keys( def update_keys(
secret_path: Path, keys: Iterable[SopsKey], age_plugins: list[str] | None = None secret_path: Path,
keys: Iterable[SopsKey],
age_plugins: list[str] | None = None,
) -> list[Path]: ) -> list[Path]:
secret_path = secret_path / "secret" secret_path = secret_path / "secret"
error_msg = f"Could not update keys for {secret_path}" error_msg = f"Could not update keys for {secret_path}"
@@ -565,7 +570,7 @@ def get_recipients(secret_path: Path) -> set[SopsKey]:
username="", username="",
key_type=key_type, key_type=key_type,
source="sops_file", source="sops_file",
) ),
) )
return keys return keys

View File

@@ -66,7 +66,7 @@ def remove_user(flake_dir: Path, name: str) -> None:
continue continue
log.info(f"Removing user {name} from group {group}") log.info(f"Removing user {name} from group {group}")
updated_paths.extend( updated_paths.extend(
groups.remove_member(flake_dir, group.name, groups.users_folder, name) groups.remove_member(flake_dir, group.name, groups.users_folder, name),
) )
# Remove the user's key: # Remove the user's key:
updated_paths.extend(remove_object(sops_users_folder(flake_dir), name)) updated_paths.extend(remove_object(sops_users_folder(flake_dir), name))
@@ -96,7 +96,10 @@ def list_users(flake_dir: Path) -> list[str]:
def add_secret( def add_secret(
flake_dir: Path, user: str, secret: str, age_plugins: list[str] | None flake_dir: Path,
user: str,
secret: str,
age_plugins: list[str] | None,
) -> None: ) -> None:
updated_paths = secrets.allow_member( updated_paths = secrets.allow_member(
secrets.users_folder(sops_secrets_folder(flake_dir) / secret), secrets.users_folder(sops_secrets_folder(flake_dir) / secret),
@@ -112,10 +115,15 @@ def add_secret(
def remove_secret( def remove_secret(
flake_dir: Path, user: str, secret: str, age_plugins: list[str] | None flake_dir: Path,
user: str,
secret: str,
age_plugins: list[str] | None,
) -> None: ) -> None:
updated_paths = secrets.disallow_member( updated_paths = secrets.disallow_member(
secrets.users_folder(sops_secrets_folder(flake_dir) / secret), user, age_plugins secrets.users_folder(sops_secrets_folder(flake_dir) / secret),
user,
age_plugins,
) )
commit_files( commit_files(
updated_paths, updated_paths,
@@ -189,7 +197,7 @@ def _key_args(args: argparse.Namespace) -> Iterable[sops.SopsKey]:
] ]
if args.agekey: if args.agekey:
age_keys.append( age_keys.append(
sops.SopsKey(args.agekey, "", sops.KeyType.AGE, source="cmdline") sops.SopsKey(args.agekey, "", sops.KeyType.AGE, source="cmdline"),
) )
pgp_keys = [ pgp_keys = [
@@ -260,7 +268,10 @@ def register_users_parser(parser: argparse.ArgumentParser) -> None:
add_parser = subparser.add_parser("add", help="add a user") add_parser = subparser.add_parser("add", help="add a user")
add_parser.add_argument( add_parser.add_argument(
"-f", "--force", help="overwrite existing user", action="store_true" "-f",
"--force",
help="overwrite existing user",
action="store_true",
) )
add_parser.add_argument("user", help="the name of the user", type=user_name_type) add_parser.add_argument("user", help="the name of the user", type=user_name_type)
_add_key_flags(add_parser) _add_key_flags(add_parser)
@@ -268,59 +279,79 @@ def register_users_parser(parser: argparse.ArgumentParser) -> None:
get_parser = subparser.add_parser("get", help="get a user public key") get_parser = subparser.add_parser("get", help="get a user public key")
get_user_action = get_parser.add_argument( get_user_action = get_parser.add_argument(
"user", help="the name of the user", type=user_name_type "user",
help="the name of the user",
type=user_name_type,
) )
add_dynamic_completer(get_user_action, complete_users) add_dynamic_completer(get_user_action, complete_users)
get_parser.set_defaults(func=get_command) get_parser.set_defaults(func=get_command)
remove_parser = subparser.add_parser("remove", help="remove a user") remove_parser = subparser.add_parser("remove", help="remove a user")
remove_user_action = remove_parser.add_argument( remove_user_action = remove_parser.add_argument(
"user", help="the name of the user", type=user_name_type "user",
help="the name of the user",
type=user_name_type,
) )
add_dynamic_completer(remove_user_action, complete_users) add_dynamic_completer(remove_user_action, complete_users)
remove_parser.set_defaults(func=remove_command) remove_parser.set_defaults(func=remove_command)
add_secret_parser = subparser.add_parser( add_secret_parser = subparser.add_parser(
"add-secret", help="allow a user to access a secret" "add-secret",
help="allow a user to access a secret",
) )
add_secret_user_action = add_secret_parser.add_argument( add_secret_user_action = add_secret_parser.add_argument(
"user", help="the name of the user", type=user_name_type "user",
help="the name of the user",
type=user_name_type,
) )
add_dynamic_completer(add_secret_user_action, complete_users) add_dynamic_completer(add_secret_user_action, complete_users)
add_secrets_action = add_secret_parser.add_argument( add_secrets_action = add_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type "secret",
help="the name of the secret",
type=secret_name_type,
) )
add_dynamic_completer(add_secrets_action, complete_secrets) add_dynamic_completer(add_secrets_action, complete_secrets)
add_secret_parser.set_defaults(func=add_secret_command) add_secret_parser.set_defaults(func=add_secret_command)
remove_secret_parser = subparser.add_parser( remove_secret_parser = subparser.add_parser(
"remove-secret", help="remove a user's access to a secret" "remove-secret",
help="remove a user's access to a secret",
) )
remove_secret_user_action = remove_secret_parser.add_argument( remove_secret_user_action = remove_secret_parser.add_argument(
"user", help="the name of the group", type=user_name_type "user",
help="the name of the group",
type=user_name_type,
) )
add_dynamic_completer(remove_secret_user_action, complete_users) add_dynamic_completer(remove_secret_user_action, complete_users)
remove_secrets_action = remove_secret_parser.add_argument( remove_secrets_action = remove_secret_parser.add_argument(
"secret", help="the name of the secret", type=secret_name_type "secret",
help="the name of the secret",
type=secret_name_type,
) )
add_dynamic_completer(remove_secrets_action, complete_secrets) add_dynamic_completer(remove_secrets_action, complete_secrets)
remove_secret_parser.set_defaults(func=remove_secret_command) remove_secret_parser.set_defaults(func=remove_secret_command)
add_key_parser = subparser.add_parser( add_key_parser = subparser.add_parser(
"add-key", help="add one or more keys for a user" "add-key",
help="add one or more keys for a user",
) )
add_key_user_action = add_key_parser.add_argument( add_key_user_action = add_key_parser.add_argument(
"user", help="the name of the user", type=user_name_type "user",
help="the name of the user",
type=user_name_type,
) )
add_dynamic_completer(add_key_user_action, complete_users) add_dynamic_completer(add_key_user_action, complete_users)
_add_key_flags(add_key_parser) _add_key_flags(add_key_parser)
add_key_parser.set_defaults(func=add_key_command) add_key_parser.set_defaults(func=add_key_command)
remove_key_parser = subparser.add_parser( remove_key_parser = subparser.add_parser(
"remove-key", help="remove one or more keys for a user" "remove-key",
help="remove one or more keys for a user",
) )
remove_key_user_action = remove_key_parser.add_argument( remove_key_user_action = remove_key_parser.add_argument(
"user", help="the name of the user", type=user_name_type "user",
help="the name of the user",
type=user_name_type,
) )
add_dynamic_completer(remove_key_user_action, complete_users) add_dynamic_completer(remove_key_user_action, complete_users)
_add_key_flags(remove_key_parser) _add_key_flags(remove_key_parser)

View File

@@ -64,7 +64,8 @@ def ssh_command(args: argparse.Namespace) -> None:
ssh_options[name] = value ssh_options[name] = value
remote = remote.override( remote = remote.override(
host_key_check=args.host_key_check, ssh_options=ssh_options host_key_check=args.host_key_check,
ssh_options=ssh_options,
) )
if args.remote_command: if args.remote_command:
remote.interactive_ssh(args.remote_command) remote.interactive_ssh(args.remote_command)

View File

@@ -147,7 +147,7 @@ def test_ssh_shell_from_deploy(
str(success_txt), str(success_txt),
"&&", "&&",
"exit 0", "exit 0",
] ],
) )
assert success_txt.exists() assert success_txt.exists()

View File

@@ -25,7 +25,7 @@ def list_state_folders(machine: Machine, service: None | str = None) -> None:
[ [
f"{flake}#nixosConfigurations.{machine.name}.config.clan.core.state", f"{flake}#nixosConfigurations.{machine.name}.config.clan.core.state",
"--json", "--json",
] ],
) )
res = "{}" res = "{}"
@@ -80,7 +80,7 @@ def list_state_folders(machine: Machine, service: None | str = None) -> None:
if post_restore: if post_restore:
print(f" postRestoreCommand: {post_restore}") print(f" postRestoreCommand: {post_restore}")
print("") print()
def list_command(args: argparse.Namespace) -> None: def list_command(args: argparse.Namespace) -> None:

View File

@@ -7,7 +7,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core @pytest.mark.with_core
def test_state_list_vm1( def test_state_list_vm1(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run(["state", "list", "vm1", "--flake", str(test_flake_with_core.path)]) cli.run(["state", "list", "vm1", "--flake", str(test_flake_with_core.path)])
@@ -19,7 +20,8 @@ def test_state_list_vm1(
@pytest.mark.with_core @pytest.mark.with_core
def test_state_list_vm2( def test_state_list_vm2(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run(["state", "list", "vm2", "--flake", str(test_flake_with_core.path)]) cli.run(["state", "list", "vm2", "--flake", str(test_flake_with_core.path)])

View File

@@ -15,7 +15,8 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
) )
list_parser = subparser.add_parser("list", help="List available templates") list_parser = subparser.add_parser("list", help="List available templates")
apply_parser = subparser.add_parser( apply_parser = subparser.add_parser(
"apply", help="Apply a template of the specified type" "apply",
help="Apply a template of the specified type",
) )
register_list_parser(list_parser) register_list_parser(list_parser)
register_apply_parser(apply_parser) register_apply_parser(apply_parser)

View File

@@ -12,10 +12,11 @@ def list_command(args: argparse.Namespace) -> None:
# Display all templates # Display all templates
for i, (template_type, _builtin_template_set) in enumerate( for i, (template_type, _builtin_template_set) in enumerate(
templates.builtins.items() templates.builtins.items(),
): ):
builtin_template_set: TemplateClanType | None = templates.builtins.get( builtin_template_set: TemplateClanType | None = templates.builtins.get(
template_type, None template_type,
None,
) # type: ignore ) # type: ignore
if not builtin_template_set: if not builtin_template_set:
continue continue
@@ -32,7 +33,8 @@ def list_command(args: argparse.Namespace) -> None:
for i, (input_name, input_templates) in enumerate(templates.custom.items()): for i, (input_name, input_templates) in enumerate(templates.custom.items()):
custom_templates: TemplateClanType | None = input_templates.get( custom_templates: TemplateClanType | None = input_templates.get(
template_type, None template_type,
None,
) # type: ignore ) # type: ignore
if not custom_templates: if not custom_templates:
continue continue
@@ -48,11 +50,11 @@ def list_command(args: argparse.Namespace) -> None:
is_last_template = i == len(custom_templates.items()) - 1 is_last_template = i == len(custom_templates.items()) - 1
if not is_last_template: if not is_last_template:
print( print(
f"{prefix} ├── {name}: {template.get('description', 'no description')}" f"{prefix} ├── {name}: {template.get('description', 'no description')}",
) )
else: else:
print( print(
f"{prefix} └── {name}: {template.get('description', 'no description')}" f"{prefix} └── {name}: {template.get('description', 'no description')}",
) )

View File

@@ -9,7 +9,8 @@ from clan_cli.tests.stdout import CaptureOutput
@pytest.mark.with_core @pytest.mark.with_core
def test_templates_list( def test_templates_list(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run(["templates", "list", "--flake", str(test_flake_with_core.path)]) cli.run(["templates", "list", "--flake", str(test_flake_with_core.path)])
@@ -26,7 +27,8 @@ def test_templates_list(
@pytest.mark.with_core @pytest.mark.with_core
def test_templates_list_outside_clan( def test_templates_list_outside_clan(
capture_output: CaptureOutput, temp_dir: Path capture_output: CaptureOutput,
temp_dir: Path,
) -> None: ) -> None:
"""Test templates list command when run outside a clan directory.""" """Test templates list command when run outside a clan directory."""
with capture_output as output: with capture_output as output:

View File

@@ -37,7 +37,7 @@ class SopsSetup:
"--user", "--user",
self.user, self.user,
"--no-interactive", "--no-interactive",
] ],
) )

View File

@@ -54,8 +54,7 @@ class Command:
@pytest.fixture @pytest.fixture
def command() -> Iterator[Command]: def command() -> Iterator[Command]:
""" """Starts a background command. The process is automatically terminated in the end.
Starts a background command. The process is automatically terminated in the end.
>>> p = command.run(["some", "daemon"]) >>> p = command.run(["some", "daemon"])
>>> print(p.pid) >>> print(p.pid)
""" """

View File

@@ -39,8 +39,7 @@ def def_value() -> defaultdict:
def nested_dict() -> defaultdict: def nested_dict() -> defaultdict:
""" """Creates a defaultdict that allows for arbitrary levels of nesting.
Creates a defaultdict that allows for arbitrary levels of nesting.
For example: d['a']['b']['c'] = value For example: d['a']['b']['c'] = value
""" """
return defaultdict(def_value) return defaultdict(def_value)
@@ -75,7 +74,8 @@ def substitute(
if clan_core_replacement: if clan_core_replacement:
line = line.replace("__CLAN_CORE__", clan_core_replacement) line = line.replace("__CLAN_CORE__", clan_core_replacement)
line = line.replace( line = line.replace(
"git+https://git.clan.lol/clan/clan-core", clan_core_replacement "git+https://git.clan.lol/clan/clan-core",
clan_core_replacement,
) )
line = line.replace( line = line.replace(
"https://git.clan.lol/clan/clan-core/archive/main.tar.gz", "https://git.clan.lol/clan/clan-core/archive/main.tar.gz",
@@ -133,8 +133,7 @@ def init_git(monkeypatch: pytest.MonkeyPatch, flake: Path) -> None:
class ClanFlake: class ClanFlake:
""" """This class holds all attributes for generating a clan flake.
This class holds all attributes for generating a clan flake.
For example, inventory and machine configs can be set via self.inventory and self.machines["my_machine"] = {...}. For example, inventory and machine configs can be set via self.inventory and self.machines["my_machine"] = {...}.
Whenever a flake's configuration is changed, it needs to be re-generated by calling refresh(). Whenever a flake's configuration is changed, it needs to be re-generated by calling refresh().
@@ -179,7 +178,7 @@ class ClanFlake:
if not suppress_tmp_home_warning: if not suppress_tmp_home_warning:
if "/tmp" not in str(os.environ.get("HOME")): if "/tmp" not in str(os.environ.get("HOME")):
log.warning( log.warning(
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}" f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}",
) )
def copy( def copy(
@@ -236,7 +235,7 @@ class ClanFlake:
inventory_path = self.path / "inventory.json" inventory_path = self.path / "inventory.json"
inventory_path.write_text(json.dumps(self.inventory, indent=2)) inventory_path.write_text(json.dumps(self.inventory, indent=2))
imports = "\n".join( imports = "\n".join(
[f"clan-core.clanModules.{module}" for module in self.clan_modules] [f"clan-core.clanModules.{module}" for module in self.clan_modules],
) )
for machine_name, machine_config in self.machines.items(): for machine_name, machine_config in self.machines.items():
configuration_nix = ( configuration_nix = (
@@ -252,7 +251,7 @@ class ClanFlake:
{imports} {imports}
]; ];
}} }}
""" """,
) )
machine = Machine(name=machine_name, flake=Flake(str(self.path))) machine = Machine(name=machine_name, flake=Flake(str(self.path)))
set_machine_settings(machine, machine_config) set_machine_settings(machine, machine_config)
@@ -309,8 +308,7 @@ def create_flake(
machine_configs: dict[str, dict] | None = None, machine_configs: dict[str, dict] | None = None,
inventory_expr: str = r"{}", inventory_expr: str = r"{}",
) -> Iterator[FlakeForTest]: ) -> Iterator[FlakeForTest]:
""" """Creates a flake with the given name and machines.
Creates a flake with the given name and machines.
The machine names map to the machines in ./test_machines The machine names map to the machines in ./test_machines
""" """
if machine_configs is None: if machine_configs is None:
@@ -372,7 +370,7 @@ def create_flake(
if "/tmp" not in str(os.environ.get("HOME")): if "/tmp" not in str(os.environ.get("HOME")):
log.warning( log.warning(
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}" f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}",
) )
init_git(monkeypatch, flake) init_git(monkeypatch, flake)
@@ -382,7 +380,8 @@ def create_flake(
@pytest.fixture @pytest.fixture
def test_flake( def test_flake(
monkeypatch: pytest.MonkeyPatch, temporary_home: Path monkeypatch: pytest.MonkeyPatch,
temporary_home: Path,
) -> Iterator[FlakeForTest]: ) -> Iterator[FlakeForTest]:
yield from create_flake( yield from create_flake(
temporary_home=temporary_home, temporary_home=temporary_home,
@@ -429,8 +428,7 @@ def writable_clan_core(
clan_core: Path, clan_core: Path,
tmp_path: Path, tmp_path: Path,
) -> Path: ) -> Path:
""" """Creates a writable copy of clan_core in a temporary directory.
Creates a writable copy of clan_core in a temporary directory.
If clan_core is a git repo, copies tracked files and uncommitted changes. If clan_core is a git repo, copies tracked files and uncommitted changes.
Removes vars/ and sops/ directories if they exist. Removes vars/ and sops/ directories if they exist.
""" """
@@ -454,7 +452,9 @@ def writable_clan_core(
# Copy .git directory to maintain git functionality # Copy .git directory to maintain git functionality
if (clan_core / ".git").is_dir(): if (clan_core / ".git").is_dir():
shutil.copytree( shutil.copytree(
clan_core / ".git", temp_flake / ".git", ignore_dangling_symlinks=True clan_core / ".git",
temp_flake / ".git",
ignore_dangling_symlinks=True,
) )
else: else:
# It's a git file (for submodules/worktrees) # It's a git file (for submodules/worktrees)
@@ -478,9 +478,7 @@ def vm_test_flake(
clan_core: Path, clan_core: Path,
tmp_path: Path, tmp_path: Path,
) -> Path: ) -> Path:
""" """Creates a test flake that imports the VM test nixOS modules from clan-core."""
Creates a test flake that imports the VM test nixOS modules from clan-core.
"""
test_flake_dir = tmp_path / "test-flake" test_flake_dir = tmp_path / "test-flake"
test_flake_dir.mkdir(parents=True) test_flake_dir.mkdir(parents=True)

View File

@@ -18,7 +18,7 @@ def hosts(sshd: Sshd) -> list[Remote]:
private_key=Path(sshd.key), private_key=Path(sshd.key),
host_key_check="none", host_key_check="none",
command_prefix="local_test", command_prefix="local_test",
) ),
] ]
return group return group

View File

@@ -13,31 +13,23 @@ else:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def project_root() -> Path: def project_root() -> Path:
""" """Root directory the clan-cli"""
Root directory the clan-cli
"""
return PROJECT_ROOT return PROJECT_ROOT
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def test_root() -> Path: def test_root() -> Path:
""" """Root directory of the tests"""
Root directory of the tests
"""
return TEST_ROOT return TEST_ROOT
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def test_lib_root() -> Path: def test_lib_root() -> Path:
""" """Root directory of the clan-lib tests"""
Root directory of the clan-lib tests
"""
return PROJECT_ROOT.parent / "clan_lib" / "tests" return PROJECT_ROOT.parent / "clan_lib" / "tests"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def clan_core() -> Path: def clan_core() -> Path:
""" """Directory of the clan-core flake"""
Directory of the clan-core flake
"""
return CLAN_CORE return CLAN_CORE

View File

@@ -29,7 +29,12 @@ class Sshd:
class SshdConfig: class SshdConfig:
def __init__( def __init__(
self, path: Path, login_shell: Path, key: str, preload_lib: Path, log_file: Path self,
path: Path,
login_shell: Path,
key: str,
preload_lib: Path,
log_file: Path,
) -> None: ) -> None:
self.path = path self.path = path
self.login_shell = login_shell self.login_shell = login_shell
@@ -53,7 +58,7 @@ def sshd_config(test_root: Path) -> Iterator[SshdConfig]:
sftp_server = sshdp.parent.parent / "libexec" / "sftp-server" sftp_server = sshdp.parent.parent / "libexec" / "sftp-server"
assert sftp_server is not None assert sftp_server is not None
content = string.Template(template).substitute( content = string.Template(template).substitute(
{"host_key": host_key, "sftp_server": sftp_server} {"host_key": host_key, "sftp_server": sftp_server},
) )
config = tmpdir / "sshd_config" config = tmpdir / "sshd_config"
config.write_text(content) config.write_text(content)
@@ -74,7 +79,7 @@ if [[ -f /etc/profile ]]; then
fi fi
export PATH="{bin_path}:{path}" export PATH="{bin_path}:{path}"
exec {bash} -l "${{@}}" exec {bash} -l "${{@}}"
""" """,
) )
login_shell.chmod(0o755) login_shell.chmod(0o755)
@@ -82,7 +87,7 @@ exec {bash} -l "${{@}}"
f"""#!{bash} f"""#!{bash}
shift shift
exec "${{@}}" exec "${{@}}"
""" """,
) )
fake_sudo.chmod(0o755) fake_sudo.chmod(0o755)

View File

@@ -21,16 +21,17 @@ def should_skip(file_path: Path, excludes: list[Path]) -> bool:
def find_dataclasses_in_directory( def find_dataclasses_in_directory(
directory: Path, exclude_paths: list[str] | None = None directory: Path,
exclude_paths: list[str] | None = None,
) -> list[tuple[Path, str]]: ) -> list[tuple[Path, str]]:
""" """Find all dataclass classes in all Python files within a nested directory.
Find all dataclass classes in all Python files within a nested directory.
Args: Args:
directory (str): The root directory to start searching from. directory (str): The root directory to start searching from.
Returns: Returns:
List[Tuple[str, str]]: A list of tuples containing the file path and the dataclass name. List[Tuple[str, str]]: A list of tuples containing the file path and the dataclass name.
""" """
if exclude_paths is None: if exclude_paths is None:
exclude_paths = [] exclude_paths = []
@@ -69,10 +70,11 @@ def find_dataclasses_in_directory(
def load_dataclass_from_file( def load_dataclass_from_file(
file_path: Path, class_name: str, root_dir: str file_path: Path,
class_name: str,
root_dir: str,
) -> type | None: ) -> type | None:
""" """Load a dataclass from a given file path.
Load a dataclass from a given file path.
Args: Args:
file_path (str): Path to the file. file_path (str): Path to the file.
@@ -80,6 +82,7 @@ def load_dataclass_from_file(
Returns: Returns:
List[Type]: The dataclass type if found, else an empty list. List[Type]: The dataclass type if found, else an empty list.
""" """
module_name = ( module_name = (
os.path.relpath(file_path, root_dir).replace(os.path.sep, ".").rstrip(".py") os.path.relpath(file_path, root_dir).replace(os.path.sep, ".").rstrip(".py")
@@ -109,15 +112,14 @@ def load_dataclass_from_file(
dataclass_type = getattr(module, class_name, None) dataclass_type = getattr(module, class_name, None)
if dataclass_type and is_dataclass(dataclass_type): if dataclass_type and is_dataclass(dataclass_type):
return cast(type, dataclass_type) return cast("type", dataclass_type)
msg = f"Could not load dataclass {class_name} from file: {file_path}" msg = f"Could not load dataclass {class_name} from file: {file_path}"
raise ClanError(msg) raise ClanError(msg)
def test_all_dataclasses() -> None: def test_all_dataclasses() -> None:
""" """This Test ensures that all dataclasses are compatible with the API.
This Test ensures that all dataclasses are compatible with the API.
It will load all dataclasses from the clan_cli directory and It will load all dataclasses from the clan_cli directory and
generate a JSON schema for each of them. generate a JSON schema for each of them.
@@ -125,7 +127,6 @@ def test_all_dataclasses() -> None:
It will fail if any dataclass cannot be converted to JSON schema. It will fail if any dataclass cannot be converted to JSON schema.
This means the dataclass in its current form is not compatible with the API. This means the dataclass in its current form is not compatible with the API.
""" """
# Excludes: # Excludes:
# - API includes Type Generic wrappers, that are not known in the init file. # - API includes Type Generic wrappers, that are not known in the init file.
excludes = [ excludes = [

View File

@@ -14,5 +14,5 @@ def test_backups(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"vm1", "vm1",
] ],
) )

View File

@@ -139,7 +139,7 @@ def test_create_flake_fallback_from_non_clan_directory(
monkeypatch.setenv("LOGNAME", "testuser") monkeypatch.setenv("LOGNAME", "testuser")
cli.run( cli.run(
["flakes", "create", str(new_clan_dir), "--template=default", "--no-update"] ["flakes", "create", str(new_clan_dir), "--template=default", "--no-update"],
) )
assert (new_clan_dir / "flake.nix").exists() assert (new_clan_dir / "flake.nix").exists()
@@ -157,7 +157,7 @@ def test_create_flake_with_local_template_reference(
# TODO: should error with: localFlake does not export myLocalTemplate clan template # TODO: should error with: localFlake does not export myLocalTemplate clan template
cli.run( cli.run(
["flakes", "create", str(new_clan_dir), "--template=.#default", "--no-update"] ["flakes", "create", str(new_clan_dir), "--template=.#default", "--no-update"],
) )
assert (new_clan_dir / "flake.nix").exists() assert (new_clan_dir / "flake.nix").exists()

View File

@@ -1,17 +1,13 @@
from typing import TYPE_CHECKING
import pytest import pytest
from clan_cli.tests.fixtures_flakes import FlakeForTest from clan_cli.tests.fixtures_flakes import FlakeForTest
from clan_cli.tests.helpers import cli from clan_cli.tests.helpers import cli
from clan_cli.tests.stdout import CaptureOutput from clan_cli.tests.stdout import CaptureOutput
if TYPE_CHECKING:
pass
@pytest.mark.with_core @pytest.mark.with_core
def test_flakes_inspect( def test_flakes_inspect(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run( cli.run(
@@ -22,6 +18,6 @@ def test_flakes_inspect(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--machine", "--machine",
"vm1", "vm1",
] ],
) )
assert "Icon" in output.out assert "Icon" in output.out

View File

@@ -19,7 +19,8 @@ def test_commit_file(git_repo: Path) -> None:
# check that the latest commit message is correct # check that the latest commit message is correct
assert ( assert (
subprocess.check_output( subprocess.check_output(
["git", "log", "-1", "--pretty=%B"], cwd=git_repo ["git", "log", "-1", "--pretty=%B"],
cwd=git_repo,
).decode("utf-8") ).decode("utf-8")
== "test commit\n\n" == "test commit\n\n"
) )
@@ -59,7 +60,8 @@ def test_clan_flake_in_subdir(git_repo: Path, monkeypatch: pytest.MonkeyPatch) -
# check that the latest commit message is correct # check that the latest commit message is correct
assert ( assert (
subprocess.check_output( subprocess.check_output(
["git", "log", "-1", "--pretty=%B"], cwd=git_repo ["git", "log", "-1", "--pretty=%B"],
cwd=git_repo,
).decode("utf-8") ).decode("utf-8")
== "test commit\n\n" == "test commit\n\n"
) )

View File

@@ -28,7 +28,7 @@ def test_import_sops(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"machine1", "machine1",
age_keys[0].pubkey, age_keys[0].pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -39,7 +39,7 @@ def test_import_sops(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user1", "user1",
age_keys[1].pubkey, age_keys[1].pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -50,7 +50,7 @@ def test_import_sops(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user2", "user2",
age_keys[2].pubkey, age_keys[2].pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -61,7 +61,7 @@ def test_import_sops(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"user1", "user1",
] ],
) )
cli.run( cli.run(
[ [
@@ -72,7 +72,7 @@ def test_import_sops(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"user2", "user2",
] ],
) )
# To edit: # To edit:
@@ -98,6 +98,6 @@ def test_import_sops(
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "secret-key"] ["secrets", "get", "--flake", str(test_flake_with_core.path), "secret-key"],
) )
assert output.out == "secret-value" assert output.out == "secret-value"

View File

@@ -16,7 +16,7 @@ from clan_lib.persist.inventory_store import InventoryStore
"inventory_expr": r"""{ "inventory_expr": r"""{
machines.jon = {}; machines.jon = {};
machines.sara = {}; machines.sara = {};
}""" }""",
}, },
# TODO: Test # TODO: Test
# - Function modules # - Function modules
@@ -38,14 +38,13 @@ from clan_lib.persist.inventory_store import InventoryStore
def test_inventory_deserialize_variants( def test_inventory_deserialize_variants(
test_flake_with_core: FlakeForTest, test_flake_with_core: FlakeForTest,
) -> None: ) -> None:
""" """Testing different inventory deserializations
Testing different inventory deserializations
Inventory should always be deserializable to a dict Inventory should always be deserializable to a dict
""" """
inventory_store = InventoryStore(Flake(str(test_flake_with_core.path))) inventory_store = InventoryStore(Flake(str(test_flake_with_core.path)))
# Cast the inventory to a dict for the following assertions # Cast the inventory to a dict for the following assertions
inventory = cast(dict[str, Any], inventory_store.read()) inventory = cast("dict[str, Any]", inventory_store.read())
# Check that the inventory is a dict # Check that the inventory is a dict
assert isinstance(inventory, dict) assert isinstance(inventory, dict)

View File

@@ -27,7 +27,7 @@ def test_machine_subcommands(
"machine1", "machine1",
"--tags", "--tags",
"vm", "vm",
] ],
) )
# Usually this is done by `inventory.write` but we created a separate flake object in the test that now holds stale data # Usually this is done by `inventory.write` but we created a separate flake object in the test that now holds stale data
inventory_store._flake.invalidate_cache() inventory_store._flake.invalidate_cache()
@@ -47,7 +47,7 @@ def test_machine_subcommands(
assert "vm2" in output.out assert "vm2" in output.out
cli.run( cli.run(
["machines", "delete", "--flake", str(test_flake_with_core.path), "machine1"] ["machines", "delete", "--flake", str(test_flake_with_core.path), "machine1"],
) )
# See comment above # See comment above
inventory_store._flake.invalidate_cache() inventory_store._flake.invalidate_cache()
@@ -105,7 +105,7 @@ def test_machines_update_nonexistent_machine(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"nonexistent-machine", "nonexistent-machine",
] ],
) )
error_message = str(exc_info.value) error_message = str(exc_info.value)
@@ -130,7 +130,7 @@ def test_machines_update_typo_in_machine_name(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"v1", # typo of "vm1" "v1", # typo of "vm1"
] ],
) )
error_message = str(exc_info.value) error_message = str(exc_info.value)

View File

@@ -51,7 +51,7 @@ def _test_identities(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"foo", "foo",
age_keys[0].pubkey, age_keys[0].pubkey,
] ],
) )
assert (sops_folder / what / "foo" / "key.json").exists() assert (sops_folder / what / "foo" / "key.json").exists()
@@ -64,7 +64,7 @@ def _test_identities(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin", "admin",
admin_age_key.pubkey, admin_age_key.pubkey,
] ],
) )
with pytest.raises(ClanError): # raises "foo already exists" with pytest.raises(ClanError): # raises "foo already exists"
@@ -77,7 +77,7 @@ def _test_identities(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"foo", "foo",
age_keys[0].pubkey, age_keys[0].pubkey,
] ],
) )
with monkeypatch.context() as m: with monkeypatch.context() as m:
@@ -93,7 +93,7 @@ def _test_identities(
f"--{what_singular}", f"--{what_singular}",
"foo", "foo",
test_secret_name, test_secret_name,
] ],
) )
assert_secrets_file_recipients( assert_secrets_file_recipients(
@@ -114,7 +114,7 @@ def _test_identities(
"-f", "-f",
"foo", "foo",
age_keys[1].privkey, age_keys[1].privkey,
] ],
) )
assert_secrets_file_recipients( assert_secrets_file_recipients(
test_flake_with_core.path, test_flake_with_core.path,
@@ -131,7 +131,7 @@ def _test_identities(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"foo", "foo",
] ],
) )
assert age_keys[1].pubkey in output.out assert age_keys[1].pubkey in output.out
@@ -140,7 +140,7 @@ def _test_identities(
assert "foo" in output.out assert "foo" in output.out
cli.run( cli.run(
["secrets", what, "remove", "--flake", str(test_flake_with_core.path), "foo"] ["secrets", what, "remove", "--flake", str(test_flake_with_core.path), "foo"],
) )
assert not (sops_folder / what / "foo" / "key.json").exists() assert not (sops_folder / what / "foo" / "key.json").exists()
@@ -153,7 +153,7 @@ def _test_identities(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"foo", "foo",
] ],
) )
with capture_output as output: with capture_output as output:
@@ -178,7 +178,11 @@ def test_users(
) -> None: ) -> None:
with monkeypatch.context(): with monkeypatch.context():
_test_identities( _test_identities(
"users", test_flake_with_core, capture_output, age_keys, monkeypatch "users",
test_flake_with_core,
capture_output,
age_keys,
monkeypatch,
) )
@@ -208,7 +212,7 @@ def test_multiple_user_keys(
str(test_flake_with_core.path), str(test_flake_with_core.path),
user, user,
*[f"--age-key={key.pubkey}" for key in user_keys], *[f"--age-key={key.pubkey}" for key in user_keys],
] ],
) )
assert (sops_folder / "users" / user / "key.json").exists() assert (sops_folder / "users" / user / "key.json").exists()
@@ -222,7 +226,7 @@ def test_multiple_user_keys(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
user, user,
] ],
) )
for user_key in user_keys: for user_key in user_keys:
@@ -249,7 +253,7 @@ def test_multiple_user_keys(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
secret_name, secret_name,
] ],
) )
# check the secret has each of our user's keys as a recipient # check the secret has each of our user's keys as a recipient
@@ -268,7 +272,7 @@ def test_multiple_user_keys(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
secret_name, secret_name,
] ],
) )
assert secret_value in output.out assert secret_value in output.out
@@ -295,7 +299,7 @@ def test_multiple_user_keys(
str(test_flake_with_core.path), str(test_flake_with_core.path),
user, user,
key_to_remove.pubkey, key_to_remove.pubkey,
] ],
) )
# check the secret has been updated # check the secret has been updated
@@ -315,7 +319,7 @@ def test_multiple_user_keys(
str(test_flake_with_core.path), str(test_flake_with_core.path),
user, user,
key_to_remove.pubkey, key_to_remove.pubkey,
] ],
) )
# check the secret has been updated # check the secret has been updated
@@ -334,7 +338,11 @@ def test_machines(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
_test_identities( _test_identities(
"machines", test_flake_with_core, capture_output, age_keys, monkeypatch "machines",
test_flake_with_core,
capture_output,
age_keys,
monkeypatch,
) )
@@ -347,7 +355,7 @@ def test_groups(
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)] ["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)],
) )
assert output.out == "" assert output.out == ""
@@ -365,7 +373,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"machine1", "machine1",
] ],
) )
with pytest.raises(ClanError): # user does not exist yet with pytest.raises(ClanError): # user does not exist yet
cli.run( cli.run(
@@ -377,7 +385,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"groupb1", "groupb1",
"user1", "user1",
] ],
) )
cli.run( cli.run(
[ [
@@ -388,7 +396,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"machine1", "machine1",
machine1_age_key.pubkey, machine1_age_key.pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -399,7 +407,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"machine1", "machine1",
] ],
) )
# Should this fail? # Should this fail?
@@ -412,7 +420,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"machine1", "machine1",
] ],
) )
cli.run( cli.run(
@@ -424,7 +432,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user1", "user1",
user1_age_key.pubkey, user1_age_key.pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -435,7 +443,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin", "admin",
admin_age_key.pubkey, admin_age_key.pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -446,12 +454,12 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"user1", "user1",
] ],
) )
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)] ["secrets", "groups", "list", "--flake", str(test_flake_with_core.path)],
) )
out = output.out out = output.out
assert "user1" in out assert "user1" in out
@@ -472,7 +480,7 @@ def test_groups(
"--group", "--group",
"group1", "group1",
secret_name, secret_name,
] ],
) )
assert_secrets_file_recipients( assert_secrets_file_recipients(
@@ -498,7 +506,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"user1", "user1",
] ],
) )
assert_secrets_file_recipients( assert_secrets_file_recipients(
test_flake_with_core.path, test_flake_with_core.path,
@@ -520,7 +528,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"user1", "user1",
] ],
) )
assert_secrets_file_recipients( assert_secrets_file_recipients(
test_flake_with_core.path, test_flake_with_core.path,
@@ -541,7 +549,7 @@ def test_groups(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user1", "user1",
] ],
) )
assert_secrets_file_recipients( assert_secrets_file_recipients(
test_flake_with_core.path, test_flake_with_core.path,
@@ -562,7 +570,7 @@ def test_groups(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"group1", "group1",
"machine1", "machine1",
] ],
) )
assert_secrets_file_recipients( assert_secrets_file_recipients(
test_flake_with_core.path, test_flake_with_core.path,
@@ -629,13 +637,15 @@ def test_secrets(
# Generate a new key for the clan # Generate a new key for the clan
monkeypatch.setenv( monkeypatch.setenv(
"SOPS_AGE_KEY_FILE", str(test_flake_with_core.path / ".." / "age.key") "SOPS_AGE_KEY_FILE",
str(test_flake_with_core.path / ".." / "age.key"),
) )
with patch( with patch(
"clan_cli.secrets.key.generate_private_key", wraps=generate_private_key "clan_cli.secrets.key.generate_private_key",
wraps=generate_private_key,
) as spy: ) as spy:
cli.run( cli.run(
["secrets", "key", "generate", "--flake", str(test_flake_with_core.path)] ["secrets", "key", "generate", "--flake", str(test_flake_with_core.path)],
) )
assert spy.call_count == 1 assert spy.call_count == 1
@@ -655,18 +665,24 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"testuser", "testuser",
key["publickey"], key["publickey"],
] ],
) )
with pytest.raises(ClanError): # does not exist yet with pytest.raises(ClanError): # does not exist yet
cli.run( cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "nonexisting"] [
"secrets",
"get",
"--flake",
str(test_flake_with_core.path),
"nonexisting",
],
) )
monkeypatch.setenv("SOPS_NIX_SECRET", "foo") monkeypatch.setenv("SOPS_NIX_SECRET", "foo")
cli.run(["secrets", "set", "--flake", str(test_flake_with_core.path), "initialkey"]) cli.run(["secrets", "set", "--flake", str(test_flake_with_core.path), "initialkey"])
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "initialkey"] ["secrets", "get", "--flake", str(test_flake_with_core.path), "initialkey"],
) )
assert output.out == "foo" assert output.out == "foo"
with capture_output as output: with capture_output as output:
@@ -684,7 +700,7 @@ def test_secrets(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"initialkey", "initialkey",
] ],
) )
monkeypatch.delenv("EDITOR") monkeypatch.delenv("EDITOR")
@@ -696,7 +712,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"initialkey", "initialkey",
"key", "key",
] ],
) )
with capture_output as output: with capture_output as output:
@@ -711,7 +727,7 @@ def test_secrets(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"nonexisting", "nonexisting",
] ],
) )
assert output.out == "" assert output.out == ""
@@ -730,7 +746,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"machine1", "machine1",
age_keys[1].pubkey, age_keys[1].pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -741,18 +757,18 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"machine1", "machine1",
"key", "key",
] ],
) )
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "machines", "list", "--flake", str(test_flake_with_core.path)] ["secrets", "machines", "list", "--flake", str(test_flake_with_core.path)],
) )
assert output.out == "machine1\n" assert output.out == "machine1\n"
with use_age_key(age_keys[1].privkey, monkeypatch): with use_age_key(age_keys[1].privkey, monkeypatch):
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"] ["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
) )
assert output.out == "foo" assert output.out == "foo"
@@ -767,14 +783,14 @@ def test_secrets(
"-f", "-f",
"machine1", "machine1",
age_keys[0].privkey, age_keys[0].privkey,
] ],
) )
# should also rotate the encrypted secret # should also rotate the encrypted secret
with use_age_key(age_keys[0].privkey, monkeypatch): with use_age_key(age_keys[0].privkey, monkeypatch):
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"] ["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
) )
assert output.out == "foo" assert output.out == "foo"
@@ -787,7 +803,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"machine1", "machine1",
"key", "key",
] ],
) )
cli.run( cli.run(
@@ -799,7 +815,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user1", "user1",
age_keys[1].pubkey, age_keys[1].pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -810,7 +826,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user1", "user1",
"key", "key",
] ],
) )
with capture_output as output, use_age_key(age_keys[1].privkey, monkeypatch): with capture_output as output, use_age_key(age_keys[1].privkey, monkeypatch):
cli.run(["secrets", "get", "--flake", str(test_flake_with_core.path), "key"]) cli.run(["secrets", "get", "--flake", str(test_flake_with_core.path), "key"])
@@ -824,7 +840,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user1", "user1",
"key", "key",
] ],
) )
with pytest.raises(ClanError): # does not exist yet with pytest.raises(ClanError): # does not exist yet
@@ -837,7 +853,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin-group", "admin-group",
"key", "key",
] ],
) )
cli.run( cli.run(
[ [
@@ -848,7 +864,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin-group", "admin-group",
"user1", "user1",
] ],
) )
cli.run( cli.run(
[ [
@@ -859,7 +875,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin-group", "admin-group",
owner, owner,
] ],
) )
cli.run( cli.run(
[ [
@@ -870,7 +886,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin-group", "admin-group",
"key", "key",
] ],
) )
cli.run( cli.run(
@@ -882,13 +898,13 @@ def test_secrets(
"--group", "--group",
"admin-group", "admin-group",
"key2", "key2",
] ],
) )
with use_age_key(age_keys[1].privkey, monkeypatch): with use_age_key(age_keys[1].privkey, monkeypatch):
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"] ["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
) )
assert output.out == "foo" assert output.out == "foo"
@@ -903,7 +919,7 @@ def test_secrets(
"--pgp-key", "--pgp-key",
gpg_key.fingerprint, gpg_key.fingerprint,
"user2", "user2",
] ],
) )
# Extend group will update secrets # Extend group will update secrets
@@ -916,13 +932,13 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin-group", "admin-group",
"user2", "user2",
] ],
) )
with use_gpg_key(gpg_key, monkeypatch): # user2 with use_gpg_key(gpg_key, monkeypatch): # user2
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "get", "--flake", str(test_flake_with_core.path), "key"] ["secrets", "get", "--flake", str(test_flake_with_core.path), "key"],
) )
assert output.out == "foo" assert output.out == "foo"
@@ -935,7 +951,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin-group", "admin-group",
"user2", "user2",
] ],
) )
with ( with (
pytest.raises(ClanError), pytest.raises(ClanError),
@@ -955,7 +971,7 @@ def test_secrets(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admin-group", "admin-group",
"key", "key",
] ],
) )
cli.run(["secrets", "remove", "--flake", str(test_flake_with_core.path), "key"]) cli.run(["secrets", "remove", "--flake", str(test_flake_with_core.path), "key"])
@@ -979,7 +995,8 @@ def test_secrets_key_generate_gpg(
with ( with (
capture_output as output, capture_output as output,
patch( patch(
"clan_cli.secrets.key.generate_private_key", wraps=generate_private_key "clan_cli.secrets.key.generate_private_key",
wraps=generate_private_key,
) as spy_sops, ) as spy_sops,
): ):
cli.run( cli.run(
@@ -989,7 +1006,7 @@ def test_secrets_key_generate_gpg(
"generate", "generate",
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
] ],
) )
assert spy_sops.call_count == 0 assert spy_sops.call_count == 0
# assert "age private key" not in output.out # assert "age private key" not in output.out
@@ -1000,7 +1017,7 @@ def test_secrets_key_generate_gpg(
with capture_output as output: with capture_output as output:
cli.run( cli.run(
["secrets", "key", "show", "--flake", str(test_flake_with_core.path)] ["secrets", "key", "show", "--flake", str(test_flake_with_core.path)],
) )
key = json.loads(output.out)[0] key = json.loads(output.out)[0]
assert key["type"] == "pgp" assert key["type"] == "pgp"
@@ -1017,7 +1034,7 @@ def test_secrets_key_generate_gpg(
"--pgp-key", "--pgp-key",
gpg_key.fingerprint, gpg_key.fingerprint,
"testuser", "testuser",
] ],
) )
with capture_output as output: with capture_output as output:
@@ -1029,7 +1046,7 @@ def test_secrets_key_generate_gpg(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"testuser", "testuser",
] ],
) )
keys = json.loads(output.out) keys = json.loads(output.out)
assert len(keys) == 1 assert len(keys) == 1
@@ -1048,7 +1065,7 @@ def test_secrets_key_generate_gpg(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"secret-name", "secret-name",
] ],
) )
with capture_output as output: with capture_output as output:
cli.run( cli.run(
@@ -1058,7 +1075,7 @@ def test_secrets_key_generate_gpg(
"--flake", "--flake",
str(test_flake_with_core.path), str(test_flake_with_core.path),
"secret-name", "secret-name",
] ],
) )
assert output.out == "secret-value" assert output.out == "secret-value"
@@ -1078,7 +1095,7 @@ def test_secrets_users_add_age_plugin_error(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"testuser", "testuser",
"AGE-PLUGIN-YUBIKEY-18P5XCQVZ5FE4WKCW3NJWP", "AGE-PLUGIN-YUBIKEY-18P5XCQVZ5FE4WKCW3NJWP",
] ],
) )
error_msg = str(exc_info.value) error_msg = str(exc_info.value)

View File

@@ -31,7 +31,7 @@ def test_generate_secret(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"user1", "user1",
age_keys[0].pubkey, age_keys[0].pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -42,7 +42,7 @@ def test_generate_secret(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"admins", "admins",
"user1", "user1",
] ],
) )
cmd = [ cmd = [
"vars", "vars",
@@ -56,7 +56,7 @@ def test_generate_secret(
cli.run(cmd) cli.run(cmd)
store1 = SecretStore( store1 = SecretStore(
Machine(name="vm1", flake=Flake(str(test_flake_with_core.path))) Machine(name="vm1", flake=Flake(str(test_flake_with_core.path))),
) )
assert store1.exists("", "age.key") assert store1.exists("", "age.key")
@@ -97,13 +97,13 @@ def test_generate_secret(
str(test_flake_with_core.path), str(test_flake_with_core.path),
"--generator", "--generator",
"zerotier", "zerotier",
] ],
) )
assert age_key.lstat().st_mtime_ns == age_key_mtime assert age_key.lstat().st_mtime_ns == age_key_mtime
assert identity_secret.lstat().st_mtime_ns == secret1_mtime assert identity_secret.lstat().st_mtime_ns == secret1_mtime
store2 = SecretStore( store2 = SecretStore(
Machine(name="vm2", flake=Flake(str(test_flake_with_core.path))) Machine(name="vm2", flake=Flake(str(test_flake_with_core.path))),
) )
assert store2.exists("", "age.key") assert store2.exists("", "age.key")

View File

@@ -28,7 +28,10 @@ def test_run_environment(runtime: AsyncRuntime) -> None:
def test_run_local(runtime: AsyncRuntime) -> None: def test_run_local(runtime: AsyncRuntime) -> None:
p1 = runtime.async_run( p1 = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR) None,
host.run_local,
["echo", "hello"],
RunOpts(log=Log.STDERR),
) )
assert p1.wait().result.stdout == "hello\n" assert p1.wait().result.stdout == "hello\n"

View File

@@ -189,8 +189,8 @@ def test_generate_public_and_secret_vars(
nix_eval( nix_eval(
[ [
f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.value_with_default.value", f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.value_with_default.value",
] ],
) ),
).stdout.strip() ).stdout.strip()
assert json.loads(value_non_default) == "default_value" assert json.loads(value_non_default) == "default_value"
@@ -211,14 +211,17 @@ def test_generate_public_and_secret_vars(
public_value = get_machine_var(machine, "my_generator/my_value").printable_value public_value = get_machine_var(machine, "my_generator/my_value").printable_value
assert public_value.startswith("public") assert public_value.startswith("public")
shared_value = get_machine_var( shared_value = get_machine_var(
machine, "my_shared_generator/my_shared_value" machine,
"my_shared_generator/my_shared_value",
).printable_value ).printable_value
assert shared_value.startswith("shared") assert shared_value.startswith("shared")
vars_text = stringify_all_vars(machine) vars_text = stringify_all_vars(machine)
flake_obj = Flake(str(flake.path)) flake_obj = Flake(str(flake.path))
my_generator = Generator("my_generator", machine="my_machine", _flake=flake_obj) my_generator = Generator("my_generator", machine="my_machine", _flake=flake_obj)
dependent_generator = Generator( dependent_generator = Generator(
"dependent_generator", machine="my_machine", _flake=flake_obj "dependent_generator",
machine="my_machine",
_flake=flake_obj,
) )
in_repo_store = in_repo.FactStore(flake=flake_obj) in_repo_store = in_repo.FactStore(flake=flake_obj)
assert not in_repo_store.exists(my_generator, "my_secret") assert not in_repo_store.exists(my_generator, "my_secret")
@@ -235,8 +238,8 @@ def test_generate_public_and_secret_vars(
nix_eval( nix_eval(
[ [
f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.my_value.value", f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.my_value.value",
] ],
) ),
).stdout.strip() ).stdout.strip()
assert json.loads(vars_eval).startswith("public") assert json.loads(vars_eval).startswith("public")
@@ -244,14 +247,14 @@ def test_generate_public_and_secret_vars(
nix_eval( nix_eval(
[ [
f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.value_with_default.value", f"{flake.path}#nixosConfigurations.my_machine.config.clan.core.vars.generators.my_generator.files.value_with_default.value",
] ],
) ),
).stdout.strip() ).stdout.strip()
assert json.loads(value_non_default).startswith("non-default") assert json.loads(value_non_default).startswith("non-default")
# test regeneration works # test regeneration works
cli.run( cli.run(
["vars", "generate", "--flake", str(flake.path), "my_machine", "--regenerate"] ["vars", "generate", "--flake", str(flake.path), "my_machine", "--regenerate"],
) )
# test regeneration without sandbox # test regeneration without sandbox
cli.run( cli.run(
@@ -263,7 +266,7 @@ def test_generate_public_and_secret_vars(
"my_machine", "my_machine",
"--regenerate", "--regenerate",
"--no-sandbox", "--no-sandbox",
] ],
) )
# test stuff actually changed after regeneration # test stuff actually changed after regeneration
public_value_new = get_machine_var(machine, "my_generator/my_value").printable_value public_value_new = get_machine_var(machine, "my_generator/my_value").printable_value
@@ -273,7 +276,8 @@ def test_generate_public_and_secret_vars(
"Secret value should change after regeneration" "Secret value should change after regeneration"
) )
shared_value_new = get_machine_var( shared_value_new = get_machine_var(
machine, "my_shared_generator/my_shared_value" machine,
"my_shared_generator/my_shared_value",
).printable_value ).printable_value
assert shared_value != shared_value_new, ( assert shared_value != shared_value_new, (
"Shared value should change after regeneration" "Shared value should change after regeneration"
@@ -290,18 +294,20 @@ def test_generate_public_and_secret_vars(
"--no-sandbox", "--no-sandbox",
"--generator", "--generator",
"my_shared_generator", "my_shared_generator",
] ],
) )
# test that the shared generator is regenerated # test that the shared generator is regenerated
shared_value_after_regeneration = get_machine_var( shared_value_after_regeneration = get_machine_var(
machine, "my_shared_generator/my_shared_value" machine,
"my_shared_generator/my_shared_value",
).printable_value ).printable_value
assert shared_value_after_regeneration != shared_value_new, ( assert shared_value_after_regeneration != shared_value_new, (
"Shared value should change after regenerating my_shared_generator" "Shared value should change after regenerating my_shared_generator"
) )
# test that the dependent generator is also regenerated (because it depends on my_shared_generator) # test that the dependent generator is also regenerated (because it depends on my_shared_generator)
secret_value_after_regeneration = sops_store.get( secret_value_after_regeneration = sops_store.get(
dependent_generator, "my_secret" dependent_generator,
"my_secret",
).decode() ).decode()
assert secret_value_after_regeneration != secret_value_new, ( assert secret_value_after_regeneration != secret_value_new, (
"Dependent generator's secret should change after regenerating my_shared_generator" "Dependent generator's secret should change after regenerating my_shared_generator"
@@ -311,7 +317,8 @@ def test_generate_public_and_secret_vars(
) )
# test that my_generator is NOT regenerated (it doesn't depend on my_shared_generator) # test that my_generator is NOT regenerated (it doesn't depend on my_shared_generator)
public_value_after_regeneration = get_machine_var( public_value_after_regeneration = get_machine_var(
machine, "my_generator/my_value" machine,
"my_generator/my_value",
).printable_value ).printable_value
assert public_value_after_regeneration == public_value_new, ( assert public_value_after_regeneration == public_value_new, (
"my_generator value should NOT change after regenerating only my_shared_generator" "my_generator value should NOT change after regenerating only my_shared_generator"
@@ -348,10 +355,14 @@ def test_generate_secret_var_sops_with_default_group(
cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"]) cli.run(["vars", "generate", "--flake", str(flake.path), "my_machine"])
flake_obj = Flake(str(flake.path)) flake_obj = Flake(str(flake.path))
first_generator = Generator( first_generator = Generator(
"first_generator", machine="my_machine", _flake=flake_obj "first_generator",
machine="my_machine",
_flake=flake_obj,
) )
second_generator = Generator( second_generator = Generator(
"second_generator", machine="my_machine", _flake=flake_obj "second_generator",
machine="my_machine",
_flake=flake_obj,
) )
in_repo_store = in_repo.FactStore(flake=flake_obj) in_repo_store = in_repo.FactStore(flake=flake_obj)
assert not in_repo_store.exists(first_generator, "my_secret") assert not in_repo_store.exists(first_generator, "my_secret")
@@ -372,16 +383,22 @@ def test_generate_secret_var_sops_with_default_group(
str(flake.path), str(flake.path),
"user2", "user2",
pubkey_user2.pubkey, pubkey_user2.pubkey,
] ],
) )
cli.run(["secrets", "groups", "add-user", "my_group", "user2"]) cli.run(["secrets", "groups", "add-user", "my_group", "user2"])
# check if new user can access the secret # check if new user can access the secret
monkeypatch.setenv("USER", "user2") monkeypatch.setenv("USER", "user2")
first_generator_with_share = Generator( first_generator_with_share = Generator(
"first_generator", share=False, machine="my_machine", _flake=flake_obj "first_generator",
share=False,
machine="my_machine",
_flake=flake_obj,
) )
second_generator_with_share = Generator( second_generator_with_share = Generator(
"second_generator", share=False, machine="my_machine", _flake=flake_obj "second_generator",
share=False,
machine="my_machine",
_flake=flake_obj,
) )
assert sops_store.user_has_access("user2", first_generator_with_share, "my_secret") assert sops_store.user_has_access("user2", first_generator_with_share, "my_secret")
assert sops_store.user_has_access("user2", second_generator_with_share, "my_secret") assert sops_store.user_has_access("user2", second_generator_with_share, "my_secret")
@@ -398,7 +415,7 @@ def test_generate_secret_var_sops_with_default_group(
"--force", "--force",
"user2", "user2",
pubkey_user3.pubkey, pubkey_user3.pubkey,
] ],
) )
monkeypatch.setenv("USER", "user2") monkeypatch.setenv("USER", "user2")
assert sops_store.user_has_access("user2", first_generator_with_share, "my_secret") assert sops_store.user_has_access("user2", first_generator_with_share, "my_secret")
@@ -438,10 +455,16 @@ def test_generated_shared_secret_sops(
m2_sops_store = sops.SecretStore(machine2.flake) m2_sops_store = sops.SecretStore(machine2.flake)
# Create generators with machine context for testing # Create generators with machine context for testing
generator_m1 = Generator( generator_m1 = Generator(
"my_shared_generator", share=True, machine="machine1", _flake=machine1.flake "my_shared_generator",
share=True,
machine="machine1",
_flake=machine1.flake,
) )
generator_m2 = Generator( generator_m2 = Generator(
"my_shared_generator", share=True, machine="machine2", _flake=machine2.flake "my_shared_generator",
share=True,
machine="machine2",
_flake=machine2.flake,
) )
assert m1_sops_store.exists(generator_m1, "my_shared_secret") assert m1_sops_store.exists(generator_m1, "my_shared_secret")
@@ -492,7 +515,9 @@ def test_generate_secret_var_password_store(
check=True, check=True,
) )
subprocess.run( subprocess.run(
["git", "config", "user.name", "Test User"], cwd=password_store_dir, check=True ["git", "config", "user.name", "Test User"],
cwd=password_store_dir,
check=True,
) )
flake_obj = Flake(str(flake.path)) flake_obj = Flake(str(flake.path))
@@ -502,10 +527,18 @@ def test_generate_secret_var_password_store(
assert check_vars(machine.name, machine.flake) assert check_vars(machine.name, machine.flake)
store = password_store.SecretStore(flake=flake_obj) store = password_store.SecretStore(flake=flake_obj)
my_generator = Generator( my_generator = Generator(
"my_generator", share=False, files=[], machine="my_machine", _flake=flake_obj "my_generator",
share=False,
files=[],
machine="my_machine",
_flake=flake_obj,
) )
my_generator_shared = Generator( my_generator_shared = Generator(
"my_generator", share=True, files=[], machine="my_machine", _flake=flake_obj "my_generator",
share=True,
files=[],
machine="my_machine",
_flake=flake_obj,
) )
my_shared_generator = Generator( my_shared_generator = Generator(
"my_shared_generator", "my_shared_generator",
@@ -538,7 +571,11 @@ def test_generate_secret_var_password_store(
assert "my_generator/my_secret" in vars_text assert "my_generator/my_secret" in vars_text
my_generator = Generator( my_generator = Generator(
"my_generator", share=False, files=[], machine="my_machine", _flake=flake_obj "my_generator",
share=False,
files=[],
machine="my_machine",
_flake=flake_obj,
) )
var_name = "my_secret" var_name = "my_secret"
store.delete(my_generator, var_name) store.delete(my_generator, var_name)
@@ -547,7 +584,11 @@ def test_generate_secret_var_password_store(
store.delete_store("my_machine") store.delete_store("my_machine")
store.delete_store("my_machine") # check idempotency store.delete_store("my_machine") # check idempotency
my_generator2 = Generator( my_generator2 = Generator(
"my_generator2", share=False, files=[], machine="my_machine", _flake=flake_obj "my_generator2",
share=False,
files=[],
machine="my_machine",
_flake=flake_obj,
) )
var_name = "my_secret2" var_name = "my_secret2"
assert not store.exists(my_generator2, var_name) assert not store.exists(my_generator2, var_name)
@@ -686,9 +727,7 @@ def test_shared_vars_must_never_depend_on_machine_specific_vars(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
flake_with_sops: ClanFlake, flake_with_sops: ClanFlake,
) -> None: ) -> None:
""" """Ensure that shared vars never depend on machine specific vars."""
Ensure that shared vars never depend on machine specific vars.
"""
flake = flake_with_sops flake = flake_with_sops
config = flake.machines["my_machine"] config = flake.machines["my_machine"]
@@ -719,8 +758,7 @@ def test_multi_machine_shared_vars(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
flake_with_sops: ClanFlake, flake_with_sops: ClanFlake,
) -> None: ) -> None:
""" """Ensure that shared vars are regenerated only when they should, and also can be
Ensure that shared vars are regenerated only when they should, and also can be
accessed by all machines that should have access. accessed by all machines that should have access.
Specifically: Specifically:
@@ -752,10 +790,16 @@ def test_multi_machine_shared_vars(
in_repo_store_2 = in_repo.FactStore(machine2.flake) in_repo_store_2 = in_repo.FactStore(machine2.flake)
# Create generators with machine context for testing # Create generators with machine context for testing
generator_m1 = Generator( generator_m1 = Generator(
"shared_generator", share=True, machine="machine1", _flake=machine1.flake "shared_generator",
share=True,
machine="machine1",
_flake=machine1.flake,
) )
generator_m2 = Generator( generator_m2 = Generator(
"shared_generator", share=True, machine="machine2", _flake=machine2.flake "shared_generator",
share=True,
machine="machine2",
_flake=machine2.flake,
) )
# generate for machine 1 # generate for machine 1
cli.run(["vars", "generate", "--flake", str(flake.path), "machine1"]) cli.run(["vars", "generate", "--flake", str(flake.path), "machine1"])
@@ -771,7 +815,7 @@ def test_multi_machine_shared_vars(
# ensure shared secret stays available for all machines after regeneration # ensure shared secret stays available for all machines after regeneration
# regenerate for machine 1 # regenerate for machine 1
cli.run( cli.run(
["vars", "generate", "--flake", str(flake.path), "machine1", "--regenerate"] ["vars", "generate", "--flake", str(flake.path), "machine1", "--regenerate"],
) )
# ensure values changed # ensure values changed
new_secret_1 = sops_store_1.get(generator_m1, "my_secret") new_secret_1 = sops_store_1.get(generator_m1, "my_secret")
@@ -806,7 +850,7 @@ def test_api_set_prompts(
prompt_values={ prompt_values={
"my_generator": { "my_generator": {
"prompt1": "input1", "prompt1": "input1",
} },
}, },
) )
machine = Machine(name="my_machine", flake=Flake(str(flake.path))) machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
@@ -820,14 +864,16 @@ def test_api_set_prompts(
prompt_values={ prompt_values={
"my_generator": { "my_generator": {
"prompt1": "input2", "prompt1": "input2",
} },
}, },
) )
assert store.get(my_generator, "prompt1").decode() == "input2" assert store.get(my_generator, "prompt1").decode() == "input2"
machine = Machine(name="my_machine", flake=Flake(str(flake.path))) machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
generators = get_generators( generators = get_generators(
machine=machine, full_closure=True, include_previous_values=True machine=machine,
full_closure=True,
include_previous_values=True,
) )
# get_generators should bind the store # get_generators should bind the store
assert generators[0].files[0]._store is not None assert generators[0].files[0]._store is not None
@@ -957,7 +1003,9 @@ def test_migration(
flake_obj = Flake(str(flake.path)) flake_obj = Flake(str(flake.path))
my_generator = Generator("my_generator", machine="my_machine", _flake=flake_obj) my_generator = Generator("my_generator", machine="my_machine", _flake=flake_obj)
other_generator = Generator( other_generator = Generator(
"other_generator", machine="my_machine", _flake=flake_obj "other_generator",
machine="my_machine",
_flake=flake_obj,
) )
in_repo_store = in_repo.FactStore(flake=flake_obj) in_repo_store = in_repo.FactStore(flake=flake_obj)
sops_store = sops.SecretStore(flake=flake_obj) sops_store = sops.SecretStore(flake=flake_obj)
@@ -1023,7 +1071,8 @@ def test_fails_when_files_are_left_from_other_backend(
@pytest.mark.with_core @pytest.mark.with_core
def test_create_sops_age_secrets( def test_create_sops_age_secrets(
monkeypatch: pytest.MonkeyPatch, flake: ClanFlake monkeypatch: pytest.MonkeyPatch,
flake: ClanFlake,
) -> None: ) -> None:
monkeypatch.chdir(flake.path) monkeypatch.chdir(flake.path)
cli.run(["vars", "keygen", "--flake", str(flake.path), "--user", "user"]) cli.run(["vars", "keygen", "--flake", str(flake.path), "--user", "user"])
@@ -1111,7 +1160,7 @@ def test_dynamic_invalidation(
in { in {
clan.core.vars.generators.dependent_generator.validation = if builtins.pathExists p then builtins.readFile p else null; clan.core.vars.generators.dependent_generator.validation = if builtins.pathExists p then builtins.readFile p else null;
} }
""" """,
) )
flake.refresh() flake.refresh()

View File

@@ -29,30 +29,30 @@ def test_vm_deployment(
nix_eval( nix_eval(
[ [
f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.sops.secrets", f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.sops.secrets",
] ],
) ),
).stdout.strip() ).stdout.strip(),
) )
assert sops_secrets != {} assert sops_secrets != {}
my_secret_path = run( my_secret_path = run(
nix_eval( nix_eval(
[ [
f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.clan.core.vars.generators.m1_generator.files.my_secret.path", f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.clan.core.vars.generators.m1_generator.files.my_secret.path",
] ],
) ),
).stdout.strip() ).stdout.strip()
assert "no-such-path" not in my_secret_path assert "no-such-path" not in my_secret_path
shared_secret_path = run( shared_secret_path = run(
nix_eval( nix_eval(
[ [
f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.clan.core.vars.generators.my_shared_generator.files.shared_secret.path", f"{vm_test_flake}#nixosConfigurations.test-vm-deployment.config.clan.core.vars.generators.my_shared_generator.files.shared_secret.path",
] ],
) ),
).stdout.strip() ).stdout.strip()
assert "no-such-path" not in shared_secret_path assert "no-such-path" not in shared_secret_path
vm1_config = inspect_vm( vm1_config = inspect_vm(
machine=Machine("test-vm-deployment", Flake(str(vm_test_flake))) machine=Machine("test-vm-deployment", Flake(str(vm_test_flake))),
) )
with ExitStack() as stack: with ExitStack() as stack:
vm1 = stack.enter_context(spawn_vm(vm1_config, stdin=subprocess.DEVNULL)) vm1 = stack.enter_context(spawn_vm(vm1_config, stdin=subprocess.DEVNULL))
@@ -64,7 +64,7 @@ def test_vm_deployment(
assert result.stdout == "hello\n" assert result.stdout == "hello\n"
# check shared_secret is deployed # check shared_secret is deployed
result = qga_m1.run( result = qga_m1.run(
["cat", "/run/secrets/vars/my_shared_generator/shared_secret"] ["cat", "/run/secrets/vars/my_shared_generator/shared_secret"],
) )
assert result.stdout == "hello\n" assert result.stdout == "hello\n"
# check no_deploy_secret is not deployed # check no_deploy_secret is not deployed

View File

@@ -17,7 +17,8 @@ no_kvm = not Path("/dev/kvm").exists()
@pytest.mark.with_core @pytest.mark.with_core
def test_inspect( def test_inspect(
test_flake_with_core: FlakeForTest, capture_output: CaptureOutput test_flake_with_core: FlakeForTest,
capture_output: CaptureOutput,
) -> None: ) -> None:
with capture_output as output: with capture_output as output:
cli.run(["vms", "inspect", "--flake", str(test_flake_with_core.path), "vm1"]) cli.run(["vms", "inspect", "--flake", str(test_flake_with_core.path), "vm1"])
@@ -42,7 +43,7 @@ def test_run(
"add", "add",
"user1", "user1",
age_keys[0].pubkey, age_keys[0].pubkey,
] ],
) )
cli.run( cli.run(
[ [
@@ -51,7 +52,7 @@ def test_run(
"add-user", "add-user",
"admins", "admins",
"user1", "user1",
] ],
) )
cli.run( cli.run(
[ [
@@ -63,7 +64,7 @@ def test_run(
"shutdown", "shutdown",
"-h", "-h",
"now", "now",
] ],
) )
@@ -74,7 +75,7 @@ def test_vm_persistence(
) -> None: ) -> None:
# Use the pre-built test VM from the test flake # Use the pre-built test VM from the test flake
vm_config = inspect_vm( vm_config = inspect_vm(
machine=Machine("test-vm-persistence", Flake(str(vm_test_flake))) machine=Machine("test-vm-persistence", Flake(str(vm_test_flake))),
) )
with spawn_vm(vm_config) as vm, vm.qga_connect() as qga: with spawn_vm(vm_config) as vm, vm.qga_connect() as qga:

View File

@@ -62,9 +62,7 @@ class StoreBase(ABC):
var: "Var", var: "Var",
value: bytes, value: bytes,
) -> Path | None: ) -> Path | None:
""" """Override this method to implement the actual creation of the file"""
override this method to implement the actual creation of the file
"""
@abstractmethod @abstractmethod
def exists(self, generator: "Generator", name: str) -> bool: def exists(self, generator: "Generator", name: str) -> bool:
@@ -81,8 +79,7 @@ class StoreBase(ABC):
generators: list["Generator"] | None = None, generators: list["Generator"] | None = None,
file_name: str | None = None, file_name: str | None = None,
) -> str | None: ) -> str | None:
""" """Check the health of the store for the given machine and generators.
Check the health of the store for the given machine and generators.
This method detects any issues or inconsistencies in the store that may This method detects any issues or inconsistencies in the store that may
require fixing (e.g., outdated encryption keys, missing permissions). require fixing (e.g., outdated encryption keys, missing permissions).
@@ -94,6 +91,7 @@ class StoreBase(ABC):
Returns: Returns:
str | None: An error message describing issues found, or None if everything is healthy str | None: An error message describing issues found, or None if everything is healthy
""" """
return None return None
@@ -103,8 +101,7 @@ class StoreBase(ABC):
generators: list["Generator"] | None = None, generators: list["Generator"] | None = None,
file_name: str | None = None, file_name: str | None = None,
) -> None: ) -> None:
""" """Fix any issues with the store for the given machine and generators.
Fix any issues with the store for the given machine and generators.
This method is intended to repair or update the store when inconsistencies This method is intended to repair or update the store when inconsistencies
are detected (e.g., re-encrypting secrets with new keys, fixing permissions). are detected (e.g., re-encrypting secrets with new keys, fixing permissions).
@@ -116,6 +113,7 @@ class StoreBase(ABC):
Returns: Returns:
None None
""" """
return return
@@ -164,16 +162,15 @@ class StoreBase(ABC):
log_info = machine.info log_info = machine.info
if self.is_secret_store: if self.is_secret_store:
log.info(f"{action_str} secret var {generator.name}/{var.name}\n") log.info(f"{action_str} secret var {generator.name}/{var.name}\n")
elif value != old_val:
msg = f"{action_str} var {generator.name}/{var.name}"
if not is_migration:
msg += f"\n old: {old_val_str}\n new: {string_repr(value)}"
log_info(msg)
else: else:
if value != old_val: log_info(
msg = f"{action_str} var {generator.name}/{var.name}" f"Var {generator.name}/{var.name} remains unchanged: {old_val_str}",
if not is_migration: )
msg += f"\n old: {old_val_str}\n new: {string_repr(value)}"
log_info(msg)
else:
log_info(
f"Var {generator.name}/{var.name} remains unchanged: {old_val_str}"
)
return new_file return new_file
@abstractmethod @abstractmethod
@@ -200,8 +197,7 @@ class StoreBase(ABC):
""" """
def get_validation(self, generator: "Generator") -> str | None: def get_validation(self, generator: "Generator") -> str | None:
""" """Return the invalidation hash that indicates if a generator needs to be re-run
Return the invalidation hash that indicates if a generator needs to be re-run
due to a change in its definition due to a change in its definition
""" """
hash_file = self.directory(generator, ".validation-hash") hash_file = self.directory(generator, ".validation-hash")
@@ -210,17 +206,14 @@ class StoreBase(ABC):
return hash_file.read_text().strip() return hash_file.read_text().strip()
def set_validation(self, generator: "Generator", hash_str: str) -> Path: def set_validation(self, generator: "Generator", hash_str: str) -> Path:
""" """Store the invalidation hash that indicates if a generator needs to be re-run"""
Store the invalidation hash that indicates if a generator needs to be re-run
"""
hash_file = self.directory(generator, ".validation-hash") hash_file = self.directory(generator, ".validation-hash")
hash_file.parent.mkdir(parents=True, exist_ok=True) hash_file.parent.mkdir(parents=True, exist_ok=True)
hash_file.write_text(hash_str) hash_file.write_text(hash_str)
return hash_file return hash_file
def hash_is_valid(self, generator: "Generator") -> bool: def hash_is_valid(self, generator: "Generator") -> bool:
""" """Check if the invalidation hash is up to date
Check if the invalidation hash is up to date
If the hash is not set in nix and hasn't been stored before, it is considered valid If the hash is not set in nix and hasn't been stored before, it is considered valid
-> this provides backward and forward compatibility -> this provides backward and forward compatibility
""" """

View File

@@ -28,7 +28,9 @@ class VarStatus:
def vars_status( def vars_status(
machine_name: str, flake: Flake, generator_name: None | str = None machine_name: str,
flake: Flake,
generator_name: None | str = None,
) -> VarStatus: ) -> VarStatus:
machine = Machine(name=machine_name, flake=flake) machine = Machine(name=machine_name, flake=flake)
missing_secret_vars = [] missing_secret_vars = []
@@ -53,14 +55,14 @@ def vars_status(
for generator in generators: for generator in generators:
for file in generator.files: for file in generator.files:
file.store( file.store(
machine.secret_vars_store if file.secret else machine.public_vars_store machine.secret_vars_store if file.secret else machine.public_vars_store,
) )
file.generator(generator) file.generator(generator)
if file.secret: if file.secret:
if not machine.secret_vars_store.exists(generator, file.name): if not machine.secret_vars_store.exists(generator, file.name):
machine.info( machine.info(
f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing." f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing.",
) )
missing_secret_vars.append(file) missing_secret_vars.append(file)
else: else:
@@ -71,13 +73,13 @@ def vars_status(
) )
if msg: if msg:
machine.info( machine.info(
f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} needs update: {msg}" f"Secret var '{file.name}' for service '{generator.name}' in machine {machine.name} needs update: {msg}",
) )
unfixed_secret_vars.append(file) unfixed_secret_vars.append(file)
elif not machine.public_vars_store.exists(generator, file.name): elif not machine.public_vars_store.exists(generator, file.name):
machine.info( machine.info(
f"Public var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing." f"Public var '{file.name}' for service '{generator.name}' in machine {machine.name} is missing.",
) )
missing_public_vars.append(file) missing_public_vars.append(file)
# check if invalidation hash is up to date # check if invalidation hash is up to date
@@ -87,7 +89,7 @@ def vars_status(
): ):
invalid_generators.append(generator.name) invalid_generators.append(generator.name)
machine.info( machine.info(
f"Generator '{generator.name}' in machine {machine.name} has outdated invalidation hash." f"Generator '{generator.name}' in machine {machine.name} has outdated invalidation hash.",
) )
return VarStatus( return VarStatus(
missing_secret_vars, missing_secret_vars,
@@ -98,7 +100,9 @@ def vars_status(
def check_vars( def check_vars(
machine_name: str, flake: Flake, generator_name: None | str = None machine_name: str,
flake: Flake,
generator_name: None | str = None,
) -> bool: ) -> bool:
status = vars_status(machine_name, flake, generator_name=generator_name) status = vars_status(machine_name, flake, generator_name=generator_name)
return not ( return not (

View File

@@ -6,7 +6,8 @@ from clan_lib.errors import ClanError
def test_check_command_no_flake( def test_check_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -21,7 +21,7 @@ def generate_command(args: argparse.Namespace) -> None:
filter( filter(
lambda m: m.name in args.machines, lambda m: m.name in args.machines,
machines, machines,
) ),
) )
# prefetch all vars # prefetch all vars
@@ -32,7 +32,7 @@ def generate_command(args: argparse.Namespace) -> None:
flake.precache( flake.precache(
[ [
f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.generators.*.validationHash", f"clanInternals.machines.{system}.{{{','.join(machine_names)}}}.config.clan.core.vars.generators.*.validationHash",
] ],
) )
run_generators( run_generators(

View File

@@ -6,7 +6,8 @@ from clan_lib.errors import ClanError
def test_generate_command_no_flake( def test_generate_command_no_flake(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)

View File

@@ -19,7 +19,8 @@ log = logging.getLogger(__name__)
def dependencies_as_dir( def dependencies_as_dir(
decrypted_dependencies: dict[str, dict[str, bytes]], tmpdir: Path decrypted_dependencies: dict[str, dict[str, bytes]],
tmpdir: Path,
) -> None: ) -> None:
"""Helper function to create directory structure from decrypted dependencies.""" """Helper function to create directory structure from decrypted dependencies."""
for dep_generator, files in decrypted_dependencies.items(): for dep_generator, files in decrypted_dependencies.items():
@@ -72,13 +73,15 @@ class Generator:
flake: "Flake", flake: "Flake",
include_previous_values: bool = False, include_previous_values: bool = False,
) -> list["Generator"]: ) -> list["Generator"]:
""" """Get all generators for a machine from the flake.
Get all generators for a machine from the flake.
Args: Args:
machine_name (str): The name of the machine. machine_name (str): The name of the machine.
flake (Flake): The flake to get the generators from. flake (Flake): The flake to get the generators from.
Returns: Returns:
list[Generator]: A list of (unsorted) generators for the machine. list[Generator]: A list of (unsorted) generators for the machine.
""" """
# Get all generator metadata in one select (safe fields only) # Get all generator metadata in one select (safe fields only)
generators_data = flake.select_machine( generators_data = flake.select_machine(
@@ -146,7 +149,8 @@ class Generator:
for generator in generators: for generator in generators:
for prompt in generator.prompts: for prompt in generator.prompts:
prompt.previous_value = generator.get_previous_value( prompt.previous_value = generator.get_previous_value(
machine, prompt machine,
prompt,
) )
return generators return generators
@@ -175,8 +179,8 @@ class Generator:
machine = Machine(name=self.machine, flake=self._flake) machine = Machine(name=self.machine, flake=self._flake)
output = Path( output = Path(
machine.select( machine.select(
f'config.clan.core.vars.generators."{self.name}".finalScript' f'config.clan.core.vars.generators."{self.name}".finalScript',
) ),
) )
if tmp_store := nix_test_store(): if tmp_store := nix_test_store():
output = tmp_store.joinpath(*output.parts[1:]) output = tmp_store.joinpath(*output.parts[1:])
@@ -189,7 +193,7 @@ class Generator:
machine = Machine(name=self.machine, flake=self._flake) machine = Machine(name=self.machine, flake=self._flake)
return machine.select( return machine.select(
f'config.clan.core.vars.generators."{self.name}".validationHash' f'config.clan.core.vars.generators."{self.name}".validationHash',
) )
def decrypt_dependencies( def decrypt_dependencies(
@@ -207,6 +211,7 @@ class Generator:
Returns: Returns:
Dictionary mapping generator names to their variable values Dictionary mapping generator names to their variable values
""" """
from clan_lib.errors import ClanError from clan_lib.errors import ClanError
@@ -222,7 +227,8 @@ class Generator:
result[dep_key.name] = {} result[dep_key.name] = {}
dep_generator = next( dep_generator = next(
(g for g in generators if g.name == dep_key.name), None (g for g in generators if g.name == dep_key.name),
None,
) )
if dep_generator is None: if dep_generator is None:
msg = f"Generator {dep_key.name} not found in machine {machine.name}" msg = f"Generator {dep_key.name} not found in machine {machine.name}"
@@ -237,11 +243,13 @@ class Generator:
for file in dep_files: for file in dep_files:
if file.secret: if file.secret:
result[dep_key.name][file.name] = secret_vars_store.get( result[dep_key.name][file.name] = secret_vars_store.get(
dep_generator, file.name dep_generator,
file.name,
) )
else: else:
result[dep_key.name][file.name] = public_vars_store.get( result[dep_key.name][file.name] = public_vars_store.get(
dep_generator, file.name dep_generator,
file.name,
) )
return result return result
@@ -250,6 +258,7 @@ class Generator:
Returns: Returns:
Dictionary mapping prompt names to their values Dictionary mapping prompt names to their values
""" """
from .prompt import ask from .prompt import ask
@@ -275,6 +284,7 @@ class Generator:
machine: The machine to execute the generator for machine: The machine to execute the generator for
prompt_values: Optional dictionary of prompt values. If not provided, prompts will be asked interactively. prompt_values: Optional dictionary of prompt values. If not provided, prompts will be asked interactively.
no_sandbox: Whether to disable sandboxing when executing the generator no_sandbox: Whether to disable sandboxing when executing the generator
""" """
import os import os
import sys import sys
@@ -333,8 +343,8 @@ class Generator:
"--uid", "1000", "--uid", "1000",
"--gid", "1000", "--gid", "1000",
"--", "--",
str(real_bash_path), "-c", generator str(real_bash_path), "-c", generator,
] ],
) )
# fmt: on # fmt: on
@@ -418,11 +428,11 @@ class Generator:
if validation is not None: if validation is not None:
if public_changed: if public_changed:
files_to_commit.append( files_to_commit.append(
machine.public_vars_store.set_validation(self, validation) machine.public_vars_store.set_validation(self, validation),
) )
if secret_changed: if secret_changed:
files_to_commit.append( files_to_commit.append(
machine.secret_vars_store.set_validation(self, validation) machine.secret_vars_store.set_validation(self, validation),
) )
commit_files( commit_files(

View File

@@ -33,7 +33,7 @@ def get_machine_var(machine: Machine, var_id: str) -> Var:
raise ClanError(msg) raise ClanError(msg)
if len(results) > 1: if len(results) > 1:
error = f"Found multiple vars for {var_id}:\n - " + "\n - ".join( error = f"Found multiple vars for {var_id}:\n - " + "\n - ".join(
[str(var) for var in results] [str(var) for var in results],
) )
raise ClanError(error) raise ClanError(error)
# we have exactly one result at this point # we have exactly one result at this point

Some files were not shown because too many files have changed in this diff Show More