From 65a5789c5b3b4fea89221a52016240b8acda9d6f Mon Sep 17 00:00:00 2001 From: Qubasa Date: Mon, 9 Dec 2024 18:07:23 +0100 Subject: [PATCH] clan-cli: Replace HostGroup and MachineGroup with generic AsyncRuntime class. Propagate cmd prefix over thread local. Close threads on CTRL+C --- pkgs/clan-cli/clan_cli/__init__.py | 4 +- pkgs/clan-cli/clan_cli/async_run.py | 312 ++++++++++++++++++ pkgs/clan-cli/clan_cli/cmd.py | 36 +- pkgs/clan-cli/clan_cli/custom_logger.py | 3 + .../clan_cli/machines/machine_group.py | 35 -- pkgs/clan-cli/clan_cli/machines/update.py | 128 +++---- pkgs/clan-cli/clan_cli/nix/__init__.py | 2 +- pkgs/clan-cli/clan_cli/ssh/host_group.py | 214 ------------ pkgs/clan-cli/clan_cli/vars/generate.py | 3 +- pkgs/clan-cli/tests/conftest.py | 3 +- pkgs/clan-cli/tests/host_group.py | 25 -- pkgs/clan-cli/tests/hosts.py | 23 ++ pkgs/clan-cli/tests/runtime.py | 7 + .../tests/test_api_dataclass_compat.py | 1 + .../tests/test_secrets_password_store.py | 6 +- pkgs/clan-cli/tests/test_secrets_upload.py | 6 +- pkgs/clan-cli/tests/test_ssh_local.py | 83 ++--- pkgs/clan-cli/tests/test_ssh_remote.py | 94 ++++-- 18 files changed, 549 insertions(+), 436 deletions(-) create mode 100644 pkgs/clan-cli/clan_cli/async_run.py delete mode 100644 pkgs/clan-cli/clan_cli/machines/machine_group.py delete mode 100644 pkgs/clan-cli/clan_cli/ssh/host_group.py delete mode 100644 pkgs/clan-cli/tests/host_group.py create mode 100644 pkgs/clan-cli/tests/hosts.py create mode 100644 pkgs/clan-cli/tests/runtime.py diff --git a/pkgs/clan-cli/clan_cli/__init__.py b/pkgs/clan-cli/clan_cli/__init__.py index 27aeb0301..aa9ffd363 100644 --- a/pkgs/clan-cli/clan_cli/__init__.py +++ b/pkgs/clan-cli/clan_cli/__init__.py @@ -434,8 +434,8 @@ def main() -> None: else: log.error("%s", e) # noqa: TRY400 sys.exit(1) - except KeyboardInterrupt: - log.warning("Interrupted by user") + except KeyboardInterrupt as ex: + log.warning("Interrupted by user", exc_info=ex) sys.exit(1) diff --git a/pkgs/clan-cli/clan_cli/async_run.py b/pkgs/clan-cli/clan_cli/async_run.py new file mode 100644 index 000000000..59d9f0837 --- /dev/null +++ b/pkgs/clan-cli/clan_cli/async_run.py @@ -0,0 +1,312 @@ +import logging +import threading +import time +import types +import uuid +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import IO, Any, Generic, ParamSpec, TypeVar + +from clan_cli.errors import ClanError + +log = logging.getLogger(__name__) + +# Why did we create a custom AsyncRuntime instead of using asyncio? +# +# The AsyncRuntime class allows us to run functions in separate threads for asynchronous +# execution without requiring the use of the "async" keyword. By using this approach, +# functions can gracefully handle cancellation by checking if get_async_ctx().cancel is True. +# +# There was some resistance to using asyncio, partly due to challenges we faced when +# implementing it in our first web interface. Threads felt simpler and more familiar +# for our use case. +# +# That said, asyncio is generally more efficient because it uses non-blocking I/O +# and avoids issues with Python's Global Interpreter Lock (GIL), which can limit the +# performance of threads in CPU-heavy workloads. +# +# Using threads works well for us because most of the time is spent waiting for commands +# or external processes to finish, rather than performing heavy computing tasks. +# +# Note: Starting with Python 3.14, the GIL can be disabled to enable true parallelism. +# However, disabling the GIL introduces a 10-40% performance cost for Python code +# due to the overhead of additional locking. + +# Define generics for return type and call signature +R = TypeVar("R") # Return type of the callable +P = ParamSpec("P") # Parameters of the callable + + +@dataclass +class AsyncResult(Generic[R]): + _result: R | Exception + + @property + def error(self) -> Exception | None: + """ + Returns an error if the callable raised an exception. + """ + if isinstance(self._result, Exception): + return self._result + return None + + @property + def result(self) -> R: + """ + Unwraps and returns the result if no exception occurred. + Raises the exception otherwise. + """ + if isinstance(self._result, Exception): + raise self._result + return self._result + + +@dataclass +class AsyncContext: + """ + This class stores thread-local data. + """ + + prefix: str | None = None # prefix for logging + stdout: IO[bytes] | None = None # stdout of subprocesses + stderr: IO[bytes] | None = None # stderr of subprocesses + cancel: bool = False # Used to signal cancellation of task + + +@dataclass +class AsyncOpts: + """ + Options for the async_run function. + """ + + tid: str | None = None + check: bool = True + async_ctx: AsyncContext = field(default_factory=AsyncContext) + + +ASYNC_CTX_THREAD_LOCAL = threading.local() + + +def is_async_cancelled() -> bool: + """ + Check if the current task has been cancelled. + """ + return get_async_ctx().cancel + + +def get_async_ctx() -> AsyncContext: + """ + Retrieve the current AsyncContext, creating a new one if none exists. + """ + global ASYNC_CTX_THREAD_LOCAL + + if not hasattr(ASYNC_CTX_THREAD_LOCAL, "async_ctx"): + ASYNC_CTX_THREAD_LOCAL.async_ctx = AsyncContext() + return ASYNC_CTX_THREAD_LOCAL.async_ctx + + +def set_async_ctx(ctx: AsyncContext) -> None: + global ASYNC_CTX_THREAD_LOCAL + ASYNC_CTX_THREAD_LOCAL.async_ctx = ctx + + +class AsyncThread(threading.Thread, Generic[P, R]): + function: Callable[P, R] + args: Any + kwargs: Any + result: AsyncResult[R] | None + finished: bool + condition: threading.Condition + async_opts: AsyncOpts + + def __init__( + self, + async_opts: AsyncOpts, + condition: threading.Condition, + function: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """ + A threaded wrapper for running a function asynchronously. + """ + super().__init__() + self.function = function + self.args = args + self.kwargs = kwargs + self.result: AsyncResult[R] | None = None # Store the result or exception + self.finished = False # Set to True after the thread finishes execution + self.condition = condition # Shared condition variable + self.async_opts = async_opts + + def run(self) -> None: + """ + Run the function in a separate thread. + """ + try: + set_async_ctx(self.async_opts.async_ctx) + self.result = AsyncResult(_result=self.function(*self.args, **self.kwargs)) + except Exception as ex: + self.result = AsyncResult(_result=ex) + finally: + self.finished = True + # Acquire the condition lock before notifying + with self.condition: + self.condition.notify_all() # Notify waiting threads that this thread is done + + +@dataclass +class AsyncFuture(Generic[R]): + _tid: str + _runtime: "AsyncRuntime" + + def wait(self) -> AsyncResult[R]: + """ + Wait for the task to finish. + """ + if self._tid not in self._runtime.tasks: + msg = f"No task with the name '{self._tid}' exists." + raise ClanError(msg) + + thread = self._runtime.tasks[self._tid] + thread.join() + result = self.get_result() + if result is None: + msg = f"Task '{self._tid}' unexpectedly returned None." + raise ClanError(msg) + return result + + def get_result(self) -> AsyncResult[R] | None: + """ + Retrieve the result of a finished task and remove it from the task list. + """ + if self._tid not in self._runtime.tasks: + msg = f"No task with the name '{self._tid}' exists." + raise ClanError(msg) + + thread = self._runtime.tasks[self._tid] + + if not thread.finished: + return None + + # Remove the task after retrieving the result + result = thread.result + del self._runtime.tasks[self._tid] + + if result is None: + msg = f"The result for task '{self._tid}' is unexpectedly None." + raise ClanError(msg) + + return result + + +@dataclass +class AsyncRuntime: + tasks: dict[str, AsyncThread[Any, Any]] = field(default_factory=dict) + condition: threading.Condition = field(default_factory=threading.Condition) + + def async_run( + self, + opts: AsyncOpts | None, + function: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> AsyncFuture[R]: + """ + Run the given function asynchronously in a thread with a specific name and arguments. + The function's static typing is preserved. + """ + if opts is None: + opts = AsyncOpts() + + if opts.tid is None: + opts.tid = uuid.uuid4().hex + + if opts.tid in self.tasks: + msg = f"A task with the name '{opts.tid}' is already running." + raise ClanError(msg) + + # Create and start the new AsyncThread + thread = AsyncThread(opts, self.condition, function, *args, **kwargs) + self.tasks[opts.tid] = thread + thread.start() + return AsyncFuture(opts.tid, self) + + def join_all(self) -> None: + """ + Wait for all tasks to finish + """ + with self.condition: + while any( + not task.finished for task in self.tasks.values() + ): # Check if any tasks are still running + self.condition.wait() # Wait until a thread signals completion + + def check_all(self) -> None: + """ + Check if there where any errors + """ + err_count = 0 + + for name, task in self.tasks.items(): + if task.finished and task.async_opts.check: + assert task.result is not None + error = task.result.error + if log.isEnabledFor(logging.DEBUG): + log.error( + f"failed with error: {error}", + extra={"command_prefix": name}, + exc_info=error, + ) + else: + log.error( + f"failed with error: {error}", extra={"command_prefix": name} + ) + err_count += 1 + + if err_count > 0: + msg = f"{err_count} hosts failed with an error. Check the logs above" + raise ClanError(msg) + + def __enter__(self) -> "AsyncRuntime": + """ + Enter the runtime context related to this object. + """ + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + """ + Exit the runtime context related to this object. + Sets async_ctx.cancel to True to signal cancellation. + """ + for name, task in self.tasks.items(): + if not task.finished: + task.async_opts.async_ctx.cancel = True + log.debug(f"Canceling task {name}") + + +# Example usage +if __name__ == "__main__": + runtime = AsyncRuntime() + + def add(a: int, b: int) -> int: + return a + b + + def concatenate(a: str, b: str) -> str: + time.sleep(1) + msg = "Hello World" + raise ClanError(msg) + + with runtime: + p1 = runtime.async_run(None, add, 1, 2) + p2 = runtime.async_run(None, concatenate, "Hello ", "World") + + add_result = p1.wait() + print(add_result.result) # Output: 3 + concat_result = p2.wait() + print(concat_result.error) # Output: Hello World diff --git a/pkgs/clan-cli/clan_cli/cmd.py b/pkgs/clan-cli/clan_cli/cmd.py index 2db7b71c3..cc0fc7c15 100644 --- a/pkgs/clan-cli/clan_cli/cmd.py +++ b/pkgs/clan-cli/clan_cli/cmd.py @@ -17,6 +17,7 @@ from enum import Enum from pathlib import Path from typing import IO, Any +from clan_cli.async_run import get_async_ctx, is_async_cancelled from clan_cli.colors import Color from clan_cli.custom_logger import print_trace from clan_cli.errors import ClanCmdError, ClanError, CmdOut, indent_command @@ -87,7 +88,8 @@ def handle_io( stdout_extra = {} stderr_extra = {} if prefix: - stdout_extra["command_prefix"] = stderr_extra["command_prefix"] = prefix + stdout_extra["command_prefix"] = prefix + stderr_extra["command_prefix"] = prefix if msg_color and msg_color.stderr: stdout_extra["color"] = msg_color.stderr.value if msg_color and msg_color.stdout: @@ -101,6 +103,13 @@ def handle_io( description = prefix raise ClanCmdTimeoutError(msg=msg, description=description, timeout=timeout) + # Check if the command has been cancelled + if is_async_cancelled(): + cmdlog.warning("Command cancelled", extra=stderr_extra) + + # Terminate process + break + # Wait for data to be available readlist, writelist, _ = select.select(rlist, wlist, [], 0.1) if len(readlist) == 0 and len(writelist) == 0: @@ -116,7 +125,7 @@ def handle_io( # If Log.STDOUT is set, log the stdout output if ret and log in [Log.STDOUT, Log.BOTH]: - lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n") + lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n") for line in lines: cmdlog.info(line, extra=stdout_extra) @@ -133,7 +142,7 @@ def handle_io( # If Log.STDERR is set, log the stderr output if ret and log in [Log.STDERR, Log.BOTH]: - lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n") + lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n") for line in lines: cmdlog.info(line, extra=stderr_extra) @@ -273,6 +282,17 @@ def run( if options.cwd is None: options.cwd = Path.cwd() + async_ctx = get_async_ctx() + # Fill in the options from the thread-local data + # if they are not set in the options + if async_ctx: + if options.prefix is None: + options.prefix = async_ctx.prefix + if options.stdout is None: + options.stdout = async_ctx.stdout + if options.stderr is None: + options.stderr = async_ctx.stderr + if options.input: if any(not ch.isprintable() for ch in options.input.decode("ascii", "replace")): filtered_input = "<>" @@ -318,7 +338,8 @@ def run( stdout=options.stdout, stderr=options.stderr, ) - process.wait() + if not is_async_cancelled(): + process.wait() global TIME_TABLE if TIME_TABLE: @@ -336,7 +357,9 @@ def run( ) if options.check and process.returncode != 0: - raise ClanCmdError(cmd_out) + err = ClanCmdError(cmd_out) + err.msg = "Command has been cancelled" + raise err return cmd_out @@ -359,3 +382,6 @@ def run_no_stdout( cmd, opts, ) + + +# type: ignore diff --git a/pkgs/clan-cli/clan_cli/custom_logger.py b/pkgs/clan-cli/clan_cli/custom_logger.py index b56af571a..4524d78f5 100644 --- a/pkgs/clan-cli/clan_cli/custom_logger.py +++ b/pkgs/clan-cli/clan_cli/custom_logger.py @@ -59,6 +59,9 @@ class PrefixFormatter(logging.Formatter): # If command_prefix is set, color the prefix with a unique color. elif command_prefix: + command_prefix = command_prefix[ + :10 + ] # Truncate the command prefix to 10 characters. prefix_color = self.hostname_colorcode(command_prefix) format_str = color_by_tuple(f"[{command_prefix}]", fg=prefix_color) format_str += color_by_tuple(" %(message)s", fg=msg_color) diff --git a/pkgs/clan-cli/clan_cli/machines/machine_group.py b/pkgs/clan-cli/clan_cli/machines/machine_group.py deleted file mode 100644 index 4dc7c9a89..000000000 --- a/pkgs/clan-cli/clan_cli/machines/machine_group.py +++ /dev/null @@ -1,35 +0,0 @@ -from collections.abc import Callable -from typing import TypeVar - -from clan_cli.ssh.host import Host -from clan_cli.ssh.host_group import HostGroup, HostResult - -from .machines import Machine - -T = TypeVar("T") - - -class MachineGroup: - def __init__(self, machines: list[Machine]) -> None: - self.machines = machines - self.group = HostGroup([m.target_host for m in machines]) - - def __repr__(self) -> str: - return str(self) - - def __str__(self) -> str: - return f"MachineGroup({self.group})" - - def run_function( - self, func: Callable[[Machine], T], check: bool = True - ) -> list[HostResult[T]]: - """ - Function to run for each host in the group in parallel - - @func the function to call - """ - - def wrapped_func(host: Host) -> T: - return func(host.meta["machine"]) - - return self.group.run_function(wrapped_func, check=check) diff --git a/pkgs/clan-cli/clan_cli/machines/update.py b/pkgs/clan-cli/clan_cli/machines/update.py index d5d92768b..d546774d9 100644 --- a/pkgs/clan-cli/clan_cli/machines/update.py +++ b/pkgs/clan-cli/clan_cli/machines/update.py @@ -6,6 +6,7 @@ import shlex import sys from clan_cli.api import API +from clan_cli.async_run import AsyncContext, AsyncOpts, AsyncRuntime, is_async_cancelled from clan_cli.clan_uri import FlakeId from clan_cli.cmd import MsgColor, RunOpts, run from clan_cli.colors import AnsiColor @@ -24,7 +25,6 @@ from clan_cli.vars.generate import generate_vars from clan_cli.vars.upload import upload_secret_vars from .inventory import get_all_machines, get_selected_machines -from .machine_group import MachineGroup log = logging.getLogger(__name__) @@ -113,10 +113,10 @@ def update_machines(base_path: str, machines: list[InventoryMachine]) -> None: # m.override_build_host = machine.deploy.buildHost group_machines.append(m) - deploy_machine(MachineGroup(group_machines)) + deploy_machine(group_machines) -def deploy_machine(machines: MachineGroup) -> None: +def deploy_machine(machines: list[Machine]) -> None: """ Deploy to all hosts in parallel """ @@ -163,11 +163,9 @@ def deploy_machine(machines: MachineGroup) -> None: RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)), extra_env=env, ) - ret = host.run( - switch_cmd, - RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)), - extra_env=env, - ) + + if is_async_cancelled(): + return # if the machine is mobile, we retry to deploy with the mobile workaround method is_mobile = machine.deployment.get("nixosMobileWorkaround", False) @@ -187,62 +185,72 @@ def deploy_machine(machines: MachineGroup) -> None: extra_env=env, ) - if len(machines.group.hosts) > 1: - machines.run_function(deploy) - else: - deploy(machines.machines[0]) + with AsyncRuntime() as runtime: + for machine in machines: + machine.info(f"Updating {machine.name}") + runtime.async_run( + AsyncOpts( + tid=machine.name, async_ctx=AsyncContext(prefix=machine.name) + ), + deploy, + machine, + ) + runtime.join_all() -def update(args: argparse.Namespace) -> None: - if args.flake is None: - msg = "Could not find clan flake toplevel directory" - raise ClanError(msg) - machines = [] - if len(args.machines) == 1 and args.target_host is not None: - machine = Machine( - name=args.machines[0], flake=args.flake, nix_options=args.option - ) - machine.override_target_host = args.target_host - machine.override_build_host = args.build_host - machine.host_key_check = HostKeyCheck.from_str(args.host_key_check) - machines.append(machine) - - elif args.target_host is not None: - print("target host can only be specified for a single machine") - exit(1) - else: - if len(args.machines) == 0: - ignored_machines = [] - for machine in get_all_machines(args.flake, args.option): - if machine.deployment.get("requireExplicitUpdate", False): - continue - try: - machine.build_host # noqa: B018 - except ClanError: # check if we have a build host set - ignored_machines.append(machine) - continue - machine.host_key_check = HostKeyCheck.from_str(args.host_key_check) - machine.override_build_host = args.build_host - machines.append(machine) - - if not machines and ignored_machines != []: - print( - "WARNING: No machines to update." - "The following defined machines were ignored because they" - "do not have the `clan.core.networking.targetHost` nixos option set:", - file=sys.stderr, - ) - for machine in ignored_machines: - print(machine, file=sys.stderr) +def update_command(args: argparse.Namespace) -> None: + try: + if args.flake is None: + msg = "Could not find clan flake toplevel directory" + raise ClanError(msg) + machines = [] + if len(args.machines) == 1 and args.target_host is not None: + machine = Machine( + name=args.machines[0], flake=args.flake, nix_options=args.option + ) + machine.override_target_host = args.target_host + machine.override_build_host = args.build_host + machine.host_key_check = HostKeyCheck.from_str(args.host_key_check) + machines.append(machine) + elif args.target_host is not None: + print("target host can only be specified for a single machine") + exit(1) else: - machines = get_selected_machines(args.flake, args.option, args.machines) - for machine in machines: - machine.override_build_host = args.build_host - machine.host_key_check = HostKeyCheck.from_str(args.host_key_check) + if len(args.machines) == 0: + ignored_machines = [] + for machine in get_all_machines(args.flake, args.option): + if machine.deployment.get("requireExplicitUpdate", False): + continue + try: + machine.build_host # noqa: B018 + except ClanError: # check if we have a build host set + ignored_machines.append(machine) + continue + machine.host_key_check = HostKeyCheck.from_str(args.host_key_check) + machine.override_build_host = args.build_host + machines.append(machine) - host_group = MachineGroup(machines) - deploy_machine(host_group) + if not machines and ignored_machines != []: + print( + "WARNING: No machines to update." + "The following defined machines were ignored because they" + "do not have the `clan.core.networking.targetHost` nixos option set:", + file=sys.stderr, + ) + for machine in ignored_machines: + print(machine, file=sys.stderr) + + else: + machines = get_selected_machines(args.flake, args.option, args.machines) + for machine in machines: + machine.override_build_host = args.build_host + machine.host_key_check = HostKeyCheck.from_str(args.host_key_check) + + deploy_machine(machines) + except KeyboardInterrupt: + log.warning("Interrupted by user") + sys.exit(1) def register_update_parser(parser: argparse.ArgumentParser) -> None: @@ -272,4 +280,4 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None: type=str, help="Address of the machine to build the flake, in the format of user@host:1234.", ) - parser.set_defaults(func=update) + parser.set_defaults(func=update_command) diff --git a/pkgs/clan-cli/clan_cli/nix/__init__.py b/pkgs/clan-cli/clan_cli/nix/__init__.py index cba4e3562..0afe3109d 100644 --- a/pkgs/clan-cli/clan_cli/nix/__init__.py +++ b/pkgs/clan-cli/clan_cli/nix/__init__.py @@ -49,7 +49,7 @@ def nix_add_to_gcroots(nix_path: Path, dest: Path) -> None: def nix_config() -> dict[str, Any]: - cmd = nix_command(["show-config", "--json"]) + cmd = nix_command(["config", "show", "--json"]) proc = run_no_stdout(cmd) data = json.loads(proc.stdout) config = {} diff --git a/pkgs/clan-cli/clan_cli/ssh/host_group.py b/pkgs/clan-cli/clan_cli/ssh/host_group.py deleted file mode 100644 index e3c12c95f..000000000 --- a/pkgs/clan-cli/clan_cli/ssh/host_group.py +++ /dev/null @@ -1,214 +0,0 @@ -import logging -from collections.abc import Callable -from dataclasses import dataclass -from threading import Thread -from typing import Any - -from clan_cli.cmd import RunOpts -from clan_cli.errors import ClanError -from clan_cli.ssh import T -from clan_cli.ssh.host import Host -from clan_cli.ssh.results import HostResult, Results - -cmdlog = logging.getLogger(__name__) - - -def _worker( - func: Callable[[Host], T], - host: Host, - results: list[HostResult[T]], - idx: int, -) -> None: - try: - results[idx] = HostResult(host, func(host)) - except Exception as e: - results[idx] = HostResult(host, e) - - -@dataclass -class HostGroup: - hosts: list[Host] - - def _run_local( - self, - *, - cmd: list[str], - opts: RunOpts, - extra_env: dict[str, str] | None, - host: Host, - results: Results, - verbose_ssh: bool = False, - tty: bool = False, - ) -> None: - try: - proc = host.run_local( - cmd, - opts, - extra_env, - ) - results.append(HostResult(host, proc)) - except Exception as e: - results.append(HostResult(host, e)) - - def _run_remote( - self, - cmd: list[str], - opts: RunOpts, - host: Host, - results: Results, - extra_env: dict[str, str] | None = None, - verbose_ssh: bool = False, - tty: bool = False, - ) -> None: - try: - proc = host.run( - cmd, - extra_env=extra_env, - verbose_ssh=verbose_ssh, - opts=opts, - tty=tty, - ) - results.append(HostResult(host, proc)) - except Exception as e: - results.append(HostResult(host, e)) - - def _reraise_errors(self, results: list[HostResult[Any]]) -> None: - errors = 0 - for result in results: - e = result.error - if e: - cmdlog.error( - f"failed with: {e}", - extra={"command_prefix": result.host.command_prefix}, - ) - errors += 1 - if errors > 0: - msg = f"{errors} hosts failed with an error. Check the logs above" - raise ClanError(msg) from e - - def _run( - self, - cmd: list[str], - opts: RunOpts | None = None, - local: bool = False, - extra_env: dict[str, str] | None = None, - verbose_ssh: bool = False, - tty: bool = False, - ) -> Results: - results: Results = [] - threads = [] - - if opts is None: - opts = RunOpts() - - for host in self.hosts: - if local: - thread = Thread( - target=self._run_local, - kwargs={ - "cmd": cmd, - "opts": opts, - "host": host, - "results": results, - "extra_env": extra_env, - "verbose_ssh": verbose_ssh, - "tty": tty, - }, - ) - else: - thread = Thread( - target=self._run_remote, - kwargs={ - "results": results, - "cmd": cmd, - "host": host, - "opts": opts, - "extra_env": extra_env, - "verbose_ssh": verbose_ssh, - "tty": tty, - }, - ) - - thread.start() - threads.append(thread) - - for thread in threads: - thread.join() - - if opts.check: - self._reraise_errors(results) - - return results - - def run( - self, - cmd: list[str], - opts: RunOpts | None = None, - *, - extra_env: dict[str, str] | None = None, - verbose_ssh: bool = False, - tty: bool = False, - ) -> Results: - """ - Command to run on the remote host via ssh - - @return a lists of tuples containing Host and the result of the command for this Host - """ - return self._run( - cmd, - opts, - extra_env=extra_env, - verbose_ssh=verbose_ssh, - tty=tty, - ) - - def run_local( - self, - cmd: list[str], - opts: RunOpts | None = None, - *, - extra_env: dict[str, str] | None = None, - ) -> Results: - """ - Command to run locally for each host in the group in parallel - - @return a lists of tuples containing Host and the result of the command for this Host - """ - - return self._run( - cmd, - opts, - local=True, - extra_env=extra_env, - ) - - def run_function( - self, func: Callable[[Host], T], check: bool = True - ) -> list[HostResult[T]]: - """ - Function to run for each host in the group in parallel - """ - threads = [] - results: list[HostResult[T]] = [ - HostResult(h, ClanError(f"No result set for thread {i}")) - for (i, h) in enumerate(self.hosts) - ] - for i, host in enumerate(self.hosts): - thread = Thread( - target=_worker, - args=(func, host, results, i), - ) - threads.append(thread) - - for thread in threads: - thread.start() - - for thread in threads: - thread.join() - if check: - self._reraise_errors(results) - return results - - def filter(self, pred: Callable[[Host], bool]) -> "HostGroup": - """Return a new Group with the results filtered by the predicate""" - return HostGroup(list(filter(pred, self.hosts))) diff --git a/pkgs/clan-cli/clan_cli/vars/generate.py b/pkgs/clan-cli/clan_cli/vars/generate.py index dd9598636..9afa7d431 100644 --- a/pkgs/clan-cli/clan_cli/vars/generate.py +++ b/pkgs/clan-cli/clan_cli/vars/generate.py @@ -488,7 +488,8 @@ def generate_vars( raise ClanError(msg) from errors[0][1] if not was_regenerated and len(machines) > 0: - log.info("All vars are already up to date") + for machine in machines: + machine.info("All vars are already up to date") return was_regenerated diff --git a/pkgs/clan-cli/tests/conftest.py b/pkgs/clan-cli/tests/conftest.py index 1532a8618..947b4fbc2 100644 --- a/pkgs/clan-cli/tests/conftest.py +++ b/pkgs/clan-cli/tests/conftest.py @@ -12,7 +12,8 @@ pytest_plugins = [ "sshd", "command", "ports", - "host_group", + "hosts", + "runtime", "fixtures_flakes", "stdout", "nix_config", diff --git a/pkgs/clan-cli/tests/host_group.py b/pkgs/clan-cli/tests/host_group.py deleted file mode 100644 index ec04e1b01..000000000 --- a/pkgs/clan-cli/tests/host_group.py +++ /dev/null @@ -1,25 +0,0 @@ -import os -import pwd - -import pytest -from clan_cli.ssh.host import Host -from clan_cli.ssh.host_group import HostGroup -from clan_cli.ssh.host_key import HostKeyCheck -from sshd import Sshd - - -@pytest.fixture -def host_group(sshd: Sshd) -> HostGroup: - login = pwd.getpwuid(os.getuid()).pw_name - group = HostGroup( - [ - Host( - "127.0.0.1", - port=sshd.port, - user=login, - key=sshd.key, - host_key_check=HostKeyCheck.NONE, - ) - ] - ) - return group diff --git a/pkgs/clan-cli/tests/hosts.py b/pkgs/clan-cli/tests/hosts.py new file mode 100644 index 000000000..5f0b8cf60 --- /dev/null +++ b/pkgs/clan-cli/tests/hosts.py @@ -0,0 +1,23 @@ +import os +import pwd + +import pytest +from clan_cli.ssh.host import Host +from clan_cli.ssh.host_key import HostKeyCheck +from sshd import Sshd + + +@pytest.fixture +def hosts(sshd: Sshd) -> list[Host]: + login = pwd.getpwuid(os.getuid()).pw_name + group = [ + Host( + "127.0.0.1", + port=sshd.port, + user=login, + key=sshd.key, + host_key_check=HostKeyCheck.NONE, + ) + ] + + return group diff --git a/pkgs/clan-cli/tests/runtime.py b/pkgs/clan-cli/tests/runtime.py new file mode 100644 index 000000000..a13a992ff --- /dev/null +++ b/pkgs/clan-cli/tests/runtime.py @@ -0,0 +1,7 @@ +import pytest +from clan_cli.async_run import AsyncRuntime + + +@pytest.fixture +def runtime() -> AsyncRuntime: + return AsyncRuntime() diff --git a/pkgs/clan-cli/tests/test_api_dataclass_compat.py b/pkgs/clan-cli/tests/test_api_dataclass_compat.py index 34023df3b..8f63b9722 100644 --- a/pkgs/clan-cli/tests/test_api_dataclass_compat.py +++ b/pkgs/clan-cli/tests/test_api_dataclass_compat.py @@ -120,6 +120,7 @@ def test_all_dataclasses() -> None: excludes = [ "api/__init__.py", "cmd.py", # We don't want the UI to have access to the cmd module anyway + "async_run.py", # We don't want the UI to have access to the async_run module anyway ] cli_path = Path("clan_cli").resolve() diff --git a/pkgs/clan-cli/tests/test_secrets_password_store.py b/pkgs/clan-cli/tests/test_secrets_password_store.py index 19ecddd9e..bdfe71e94 100644 --- a/pkgs/clan-cli/tests/test_secrets_password_store.py +++ b/pkgs/clan-cli/tests/test_secrets_password_store.py @@ -7,7 +7,7 @@ from clan_cli.facts.secret_modules.password_store import SecretStore from clan_cli.machines.facts import machine_get_fact from clan_cli.machines.machines import Machine from clan_cli.nix import nix_shell -from clan_cli.ssh.host_group import HostGroup +from clan_cli.ssh.host import Host from fixtures_flakes import ClanFlake from helpers import cli @@ -17,7 +17,7 @@ def test_upload_secret( monkeypatch: pytest.MonkeyPatch, flake: ClanFlake, temporary_home: Path, - host_group: HostGroup, + hosts: list[Host], ) -> None: flake.clan_modules = [ "root-password", @@ -27,7 +27,7 @@ def test_upload_secret( config = flake.machines["vm1"] config["nixpkgs"]["hostPlatform"] = "x86_64-linux" config["clan"]["core"]["networking"]["zerotier"]["controller"]["enable"] = True - host = host_group.hosts[0] + host = hosts[0] addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}" config["clan"]["core"]["networking"]["targetHost"] = addr config["clan"]["user-password"]["user"] = "alice" diff --git a/pkgs/clan-cli/tests/test_secrets_upload.py b/pkgs/clan-cli/tests/test_secrets_upload.py index 7e667f43f..058bd0b7a 100644 --- a/pkgs/clan-cli/tests/test_secrets_upload.py +++ b/pkgs/clan-cli/tests/test_secrets_upload.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING import pytest -from clan_cli.ssh.host_group import HostGroup +from clan_cli.ssh.host import Host from fixtures_flakes import FlakeForTest from helpers import cli @@ -13,7 +13,7 @@ if TYPE_CHECKING: def test_secrets_upload( monkeypatch: pytest.MonkeyPatch, test_flake_with_core: FlakeForTest, - host_group: HostGroup, + hosts: list[Host], age_keys: list["KeyPair"], ) -> None: monkeypatch.chdir(test_flake_with_core.path) @@ -48,7 +48,7 @@ def test_secrets_upload( ) flake = test_flake_with_core.path.joinpath("flake.nix") - host = host_group.hosts[0] + host = hosts[0] addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}" new_text = flake.read_text().replace("__CLAN_TARGET_ADDRESS__", addr) diff --git a/pkgs/clan-cli/tests/test_ssh_local.py b/pkgs/clan-cli/tests/test_ssh_local.py index 3c0b175a0..ba1a79622 100644 --- a/pkgs/clan-cli/tests/test_ssh_local.py +++ b/pkgs/clan-cli/tests/test_ssh_local.py @@ -1,70 +1,49 @@ -from clan_cli.cmd import Log, RunOpts +from clan_cli.async_run import AsyncRuntime +from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts from clan_cli.ssh.host import Host -from clan_cli.ssh.host_group import HostGroup -hosts = HostGroup([Host("some_host")]) +host = Host("some_host") -def test_run_environment() -> None: - p2 = hosts.run_local( +def test_run_environment(runtime: AsyncRuntime) -> None: + p2 = runtime.async_run( + None, + host.run_local, ["echo $env_var"], RunOpts(shell=True, log=Log.STDERR), extra_env={"env_var": "true"}, ) - assert p2[0].result.stdout == "true\n" - p3 = hosts.run_local( - ["env"], RunOpts(shell=True, log=Log.STDERR), extra_env={"env_var": "true"} + assert p2.wait().result.stdout == "true\n" + + p3 = runtime.async_run( + None, + host.run_local, + ["env"], + RunOpts(shell=True, log=Log.STDERR), + extra_env={"env_var": "true"}, ) - assert "env_var=true" in p3[0].result.stdout + assert "env_var=true" in p3.wait().result.stdout -def test_run_local() -> None: - hosts.run_local(["echo", "hello"]) +def test_run_local(runtime: AsyncRuntime) -> None: + p1 = runtime.async_run( + None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR) + ) + assert p1.wait().result.stdout == "hello\n" -def test_timeout() -> None: - try: - hosts.run_local(["sleep", "10"], RunOpts(timeout=0.01)) - except Exception: - pass - else: - msg = "should have raised TimeoutExpired" - raise AssertionError(msg) +def test_timeout(runtime: AsyncRuntime) -> None: + p1 = runtime.async_run(None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01)) + error = p1.wait().error + assert isinstance(error, ClanCmdTimeoutError) -def test_run_function() -> None: - def some_func(h: Host) -> bool: - par = h.run_local(["echo", "hello"], RunOpts(log=Log.STDERR)) - return par.stdout == "hello\n" - - res = hosts.run_function(some_func) - assert res[0].result +def test_run_exception(runtime: AsyncRuntime) -> None: + p1 = runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True)) + assert p1.wait().error is not None -def test_run_exception() -> None: - try: - hosts.run_local(["exit 1"], RunOpts(shell=True)) - except Exception: - pass - else: - msg = "should have raised Exception" - raise AssertionError(msg) - - -def test_run_function_exception() -> None: - def some_func(h: Host) -> None: - h.run_local(["exit 1"], RunOpts(shell=True)) - - try: - hosts.run_function(some_func) - except Exception: - pass - else: - msg = "should have raised Exception" - raise AssertionError(msg) - - -def test_run_local_non_shell() -> None: - p2 = hosts.run_local(["echo", "1"], RunOpts(log=Log.STDERR)) - assert p2[0].result.stdout == "1\n" +def test_run_local_non_shell(runtime: AsyncRuntime) -> None: + p2 = runtime.async_run(None, host.run_local, ["echo", "1"], RunOpts(log=Log.STDERR)) + assert p2.wait().result.stdout == "1\n" diff --git a/pkgs/clan-cli/tests/test_ssh_remote.py b/pkgs/clan-cli/tests/test_ssh_remote.py index 8946f3bc5..dc6a3796f 100644 --- a/pkgs/clan-cli/tests/test_ssh_remote.py +++ b/pkgs/clan-cli/tests/test_ssh_remote.py @@ -1,8 +1,8 @@ import pytest -from clan_cli.cmd import Log, RunOpts +from clan_cli.async_run import AsyncRuntime +from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts from clan_cli.errors import ClanError, CmdOut from clan_cli.ssh.host import Host -from clan_cli.ssh.host_group import HostGroup from clan_cli.ssh.host_key import HostKeyCheck from clan_cli.ssh.parse import parse_deployment_address @@ -20,52 +20,75 @@ def test_parse_ipv6() -> None: host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT) -def test_run(host_group: HostGroup) -> None: - proc = host_group.run_local(["echo", "hello"], RunOpts(log=Log.STDERR)) - assert proc[0].result.stdout == "hello\n" +def test_run(hosts: list[Host], runtime: AsyncRuntime) -> None: + for host in hosts: + proc = runtime.async_run( + None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR) + ) + assert proc.wait().result.stdout == "hello\n" -def test_run_environment(host_group: HostGroup) -> None: - p1 = host_group.run( - ["echo $env_var"], - RunOpts(shell=True, log=Log.STDERR), - extra_env={"env_var": "true"}, - ) - assert p1[0].result.stdout == "true\n" - p2 = host_group.run(["env"], RunOpts(log=Log.STDERR), extra_env={"env_var": "true"}) - assert "env_var=true" in p2[0].result.stdout +def test_run_environment(hosts: list[Host], runtime: AsyncRuntime) -> None: + for host in hosts: + proc = runtime.async_run( + None, + host.run_local, + ["echo $env_var"], + RunOpts(shell=True, log=Log.STDERR), + extra_env={"env_var": "true"}, + ) + assert proc.wait().result.stdout == "true\n" + + for host in hosts: + p2 = runtime.async_run( + None, + host.run_local, + ["env"], + RunOpts(log=Log.STDERR), + extra_env={"env_var": "true"}, + ) + assert "env_var=true" in p2.wait().result.stdout -def test_run_no_shell(host_group: HostGroup) -> None: - proc = host_group.run(["echo", "$hello"], RunOpts(log=Log.STDERR)) - assert proc[0].result.stdout == "$hello\n" +def test_run_no_shell(hosts: list[Host], runtime: AsyncRuntime) -> None: + for host in hosts: + proc = runtime.async_run( + None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR) + ) + assert proc.wait().result.stdout == "hello\n" -def test_run_function(host_group: HostGroup) -> None: +def test_run_function(hosts: list[Host], runtime: AsyncRuntime) -> None: def some_func(h: Host) -> bool: p = h.run(["echo", "hello"]) return p.stdout == "hello\n" - res = host_group.run_function(some_func) - assert res[0].result + for host in hosts: + proc = runtime.async_run(None, some_func, host) + assert proc.wait().result -def test_timeout(host_group: HostGroup) -> None: - try: - host_group.run_local(["sleep", "10"], RunOpts(timeout=0.01)) - except Exception: - pass - else: - msg = "should have raised TimeoutExpired" - raise AssertionError(msg) +def test_timeout(hosts: list[Host], runtime: AsyncRuntime) -> None: + for host in hosts: + proc = runtime.async_run( + None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01) + ) + error = proc.wait().error + assert isinstance(error, ClanCmdTimeoutError) -def test_run_exception(host_group: HostGroup) -> None: - r = host_group.run(["exit 1"], RunOpts(check=False, shell=True)) - assert r[0].result.returncode == 1 +def test_run_exception(hosts: list[Host], runtime: AsyncRuntime) -> None: + for host in hosts: + proc = runtime.async_run( + None, host.run_local, ["exit 1"], RunOpts(shell=True, check=False) + ) + assert proc.wait().result.returncode == 1 try: - host_group.run(["exit 1"], RunOpts(shell=True)) + for host in hosts: + runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True)) + runtime.join_all() + runtime.check_all() except Exception: pass else: @@ -73,12 +96,15 @@ def test_run_exception(host_group: HostGroup) -> None: raise AssertionError(msg) -def test_run_function_exception(host_group: HostGroup) -> None: +def test_run_function_exception(hosts: list[Host], runtime: AsyncRuntime) -> None: def some_func(h: Host) -> CmdOut: return h.run_local(["exit 1"], RunOpts(shell=True)) try: - host_group.run_function(some_func) + for host in hosts: + runtime.async_run(None, some_func, host) + runtime.join_all() + runtime.check_all() except Exception: pass else: