Merge pull request 'clan-cli: Replace HostGroup and MachineGroup with generic AsyncRuntime class. Propagate cmd prefix over thread local. Close threads on CTRL+C' (#2580) from Qubasa/clan-core:Qubasa-main into main

This commit is contained in:
clan-bot
2024-12-09 17:13:36 +00:00
18 changed files with 549 additions and 436 deletions

View File

@@ -434,8 +434,8 @@ def main() -> None:
else:
log.error("%s", e) # noqa: TRY400
sys.exit(1)
except KeyboardInterrupt:
log.warning("Interrupted by user")
except KeyboardInterrupt as ex:
log.warning("Interrupted by user", exc_info=ex)
sys.exit(1)

View File

@@ -0,0 +1,312 @@
import logging
import threading
import time
import types
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import IO, Any, Generic, ParamSpec, TypeVar
from clan_cli.errors import ClanError
log = logging.getLogger(__name__)
# Why did we create a custom AsyncRuntime instead of using asyncio?
#
# The AsyncRuntime class allows us to run functions in separate threads for asynchronous
# execution without requiring the use of the "async" keyword. By using this approach,
# functions can gracefully handle cancellation by checking if get_async_ctx().cancel is True.
#
# There was some resistance to using asyncio, partly due to challenges we faced when
# implementing it in our first web interface. Threads felt simpler and more familiar
# for our use case.
#
# That said, asyncio is generally more efficient because it uses non-blocking I/O
# and avoids issues with Python's Global Interpreter Lock (GIL), which can limit the
# performance of threads in CPU-heavy workloads.
#
# Using threads works well for us because most of the time is spent waiting for commands
# or external processes to finish, rather than performing heavy computing tasks.
#
# Note: Starting with Python 3.14, the GIL can be disabled to enable true parallelism.
# However, disabling the GIL introduces a 10-40% performance cost for Python code
# due to the overhead of additional locking.
# Define generics for return type and call signature
R = TypeVar("R") # Return type of the callable
P = ParamSpec("P") # Parameters of the callable
@dataclass
class AsyncResult(Generic[R]):
_result: R | Exception
@property
def error(self) -> Exception | None:
"""
Returns an error if the callable raised an exception.
"""
if isinstance(self._result, Exception):
return self._result
return None
@property
def result(self) -> R:
"""
Unwraps and returns the result if no exception occurred.
Raises the exception otherwise.
"""
if isinstance(self._result, Exception):
raise self._result
return self._result
@dataclass
class AsyncContext:
"""
This class stores thread-local data.
"""
prefix: str | None = None # prefix for logging
stdout: IO[bytes] | None = None # stdout of subprocesses
stderr: IO[bytes] | None = None # stderr of subprocesses
cancel: bool = False # Used to signal cancellation of task
@dataclass
class AsyncOpts:
"""
Options for the async_run function.
"""
tid: str | None = None
check: bool = True
async_ctx: AsyncContext = field(default_factory=AsyncContext)
ASYNC_CTX_THREAD_LOCAL = threading.local()
def is_async_cancelled() -> bool:
"""
Check if the current task has been cancelled.
"""
return get_async_ctx().cancel
def get_async_ctx() -> AsyncContext:
"""
Retrieve the current AsyncContext, creating a new one if none exists.
"""
global ASYNC_CTX_THREAD_LOCAL
if not hasattr(ASYNC_CTX_THREAD_LOCAL, "async_ctx"):
ASYNC_CTX_THREAD_LOCAL.async_ctx = AsyncContext()
return ASYNC_CTX_THREAD_LOCAL.async_ctx
def set_async_ctx(ctx: AsyncContext) -> None:
global ASYNC_CTX_THREAD_LOCAL
ASYNC_CTX_THREAD_LOCAL.async_ctx = ctx
class AsyncThread(threading.Thread, Generic[P, R]):
function: Callable[P, R]
args: Any
kwargs: Any
result: AsyncResult[R] | None
finished: bool
condition: threading.Condition
async_opts: AsyncOpts
def __init__(
self,
async_opts: AsyncOpts,
condition: threading.Condition,
function: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
A threaded wrapper for running a function asynchronously.
"""
super().__init__()
self.function = function
self.args = args
self.kwargs = kwargs
self.result: AsyncResult[R] | None = None # Store the result or exception
self.finished = False # Set to True after the thread finishes execution
self.condition = condition # Shared condition variable
self.async_opts = async_opts
def run(self) -> None:
"""
Run the function in a separate thread.
"""
try:
set_async_ctx(self.async_opts.async_ctx)
self.result = AsyncResult(_result=self.function(*self.args, **self.kwargs))
except Exception as ex:
self.result = AsyncResult(_result=ex)
finally:
self.finished = True
# Acquire the condition lock before notifying
with self.condition:
self.condition.notify_all() # Notify waiting threads that this thread is done
@dataclass
class AsyncFuture(Generic[R]):
_tid: str
_runtime: "AsyncRuntime"
def wait(self) -> AsyncResult[R]:
"""
Wait for the task to finish.
"""
if self._tid not in self._runtime.tasks:
msg = f"No task with the name '{self._tid}' exists."
raise ClanError(msg)
thread = self._runtime.tasks[self._tid]
thread.join()
result = self.get_result()
if result is None:
msg = f"Task '{self._tid}' unexpectedly returned None."
raise ClanError(msg)
return result
def get_result(self) -> AsyncResult[R] | None:
"""
Retrieve the result of a finished task and remove it from the task list.
"""
if self._tid not in self._runtime.tasks:
msg = f"No task with the name '{self._tid}' exists."
raise ClanError(msg)
thread = self._runtime.tasks[self._tid]
if not thread.finished:
return None
# Remove the task after retrieving the result
result = thread.result
del self._runtime.tasks[self._tid]
if result is None:
msg = f"The result for task '{self._tid}' is unexpectedly None."
raise ClanError(msg)
return result
@dataclass
class AsyncRuntime:
tasks: dict[str, AsyncThread[Any, Any]] = field(default_factory=dict)
condition: threading.Condition = field(default_factory=threading.Condition)
def async_run(
self,
opts: AsyncOpts | None,
function: Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> AsyncFuture[R]:
"""
Run the given function asynchronously in a thread with a specific name and arguments.
The function's static typing is preserved.
"""
if opts is None:
opts = AsyncOpts()
if opts.tid is None:
opts.tid = uuid.uuid4().hex
if opts.tid in self.tasks:
msg = f"A task with the name '{opts.tid}' is already running."
raise ClanError(msg)
# Create and start the new AsyncThread
thread = AsyncThread(opts, self.condition, function, *args, **kwargs)
self.tasks[opts.tid] = thread
thread.start()
return AsyncFuture(opts.tid, self)
def join_all(self) -> None:
"""
Wait for all tasks to finish
"""
with self.condition:
while any(
not task.finished for task in self.tasks.values()
): # Check if any tasks are still running
self.condition.wait() # Wait until a thread signals completion
def check_all(self) -> None:
"""
Check if there where any errors
"""
err_count = 0
for name, task in self.tasks.items():
if task.finished and task.async_opts.check:
assert task.result is not None
error = task.result.error
if log.isEnabledFor(logging.DEBUG):
log.error(
f"failed with error: {error}",
extra={"command_prefix": name},
exc_info=error,
)
else:
log.error(
f"failed with error: {error}", extra={"command_prefix": name}
)
err_count += 1
if err_count > 0:
msg = f"{err_count} hosts failed with an error. Check the logs above"
raise ClanError(msg)
def __enter__(self) -> "AsyncRuntime":
"""
Enter the runtime context related to this object.
"""
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
"""
Exit the runtime context related to this object.
Sets async_ctx.cancel to True to signal cancellation.
"""
for name, task in self.tasks.items():
if not task.finished:
task.async_opts.async_ctx.cancel = True
log.debug(f"Canceling task {name}")
# Example usage
if __name__ == "__main__":
runtime = AsyncRuntime()
def add(a: int, b: int) -> int:
return a + b
def concatenate(a: str, b: str) -> str:
time.sleep(1)
msg = "Hello World"
raise ClanError(msg)
with runtime:
p1 = runtime.async_run(None, add, 1, 2)
p2 = runtime.async_run(None, concatenate, "Hello ", "World")
add_result = p1.wait()
print(add_result.result) # Output: 3
concat_result = p2.wait()
print(concat_result.error) # Output: Hello World

View File

@@ -17,6 +17,7 @@ from enum import Enum
from pathlib import Path
from typing import IO, Any
from clan_cli.async_run import get_async_ctx, is_async_cancelled
from clan_cli.colors import Color
from clan_cli.custom_logger import print_trace
from clan_cli.errors import ClanCmdError, ClanError, CmdOut, indent_command
@@ -87,7 +88,8 @@ def handle_io(
stdout_extra = {}
stderr_extra = {}
if prefix:
stdout_extra["command_prefix"] = stderr_extra["command_prefix"] = prefix
stdout_extra["command_prefix"] = prefix
stderr_extra["command_prefix"] = prefix
if msg_color and msg_color.stderr:
stdout_extra["color"] = msg_color.stderr.value
if msg_color and msg_color.stdout:
@@ -101,6 +103,13 @@ def handle_io(
description = prefix
raise ClanCmdTimeoutError(msg=msg, description=description, timeout=timeout)
# Check if the command has been cancelled
if is_async_cancelled():
cmdlog.warning("Command cancelled", extra=stderr_extra)
# Terminate process
break
# Wait for data to be available
readlist, writelist, _ = select.select(rlist, wlist, [], 0.1)
if len(readlist) == 0 and len(writelist) == 0:
@@ -116,7 +125,7 @@ def handle_io(
# If Log.STDOUT is set, log the stdout output
if ret and log in [Log.STDOUT, Log.BOTH]:
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n")
for line in lines:
cmdlog.info(line, extra=stdout_extra)
@@ -133,7 +142,7 @@ def handle_io(
# If Log.STDERR is set, log the stderr output
if ret and log in [Log.STDERR, Log.BOTH]:
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n")
for line in lines:
cmdlog.info(line, extra=stderr_extra)
@@ -273,6 +282,17 @@ def run(
if options.cwd is None:
options.cwd = Path.cwd()
async_ctx = get_async_ctx()
# Fill in the options from the thread-local data
# if they are not set in the options
if async_ctx:
if options.prefix is None:
options.prefix = async_ctx.prefix
if options.stdout is None:
options.stdout = async_ctx.stdout
if options.stderr is None:
options.stderr = async_ctx.stderr
if options.input:
if any(not ch.isprintable() for ch in options.input.decode("ascii", "replace")):
filtered_input = "<<binary_blob>>"
@@ -318,7 +338,8 @@ def run(
stdout=options.stdout,
stderr=options.stderr,
)
process.wait()
if not is_async_cancelled():
process.wait()
global TIME_TABLE
if TIME_TABLE:
@@ -336,7 +357,9 @@ def run(
)
if options.check and process.returncode != 0:
raise ClanCmdError(cmd_out)
err = ClanCmdError(cmd_out)
err.msg = "Command has been cancelled"
raise err
return cmd_out
@@ -359,3 +382,6 @@ def run_no_stdout(
cmd,
opts,
)
# type: ignore

View File

@@ -59,6 +59,9 @@ class PrefixFormatter(logging.Formatter):
# If command_prefix is set, color the prefix with a unique color.
elif command_prefix:
command_prefix = command_prefix[
:10
] # Truncate the command prefix to 10 characters.
prefix_color = self.hostname_colorcode(command_prefix)
format_str = color_by_tuple(f"[{command_prefix}]", fg=prefix_color)
format_str += color_by_tuple(" %(message)s", fg=msg_color)

View File

@@ -1,35 +0,0 @@
from collections.abc import Callable
from typing import TypeVar
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup, HostResult
from .machines import Machine
T = TypeVar("T")
class MachineGroup:
def __init__(self, machines: list[Machine]) -> None:
self.machines = machines
self.group = HostGroup([m.target_host for m in machines])
def __repr__(self) -> str:
return str(self)
def __str__(self) -> str:
return f"MachineGroup({self.group})"
def run_function(
self, func: Callable[[Machine], T], check: bool = True
) -> list[HostResult[T]]:
"""
Function to run for each host in the group in parallel
@func the function to call
"""
def wrapped_func(host: Host) -> T:
return func(host.meta["machine"])
return self.group.run_function(wrapped_func, check=check)

View File

@@ -6,6 +6,7 @@ import shlex
import sys
from clan_cli.api import API
from clan_cli.async_run import AsyncContext, AsyncOpts, AsyncRuntime, is_async_cancelled
from clan_cli.clan_uri import FlakeId
from clan_cli.cmd import MsgColor, RunOpts, run
from clan_cli.colors import AnsiColor
@@ -24,7 +25,6 @@ from clan_cli.vars.generate import generate_vars
from clan_cli.vars.upload import upload_secret_vars
from .inventory import get_all_machines, get_selected_machines
from .machine_group import MachineGroup
log = logging.getLogger(__name__)
@@ -113,10 +113,10 @@ def update_machines(base_path: str, machines: list[InventoryMachine]) -> None:
# m.override_build_host = machine.deploy.buildHost
group_machines.append(m)
deploy_machine(MachineGroup(group_machines))
deploy_machine(group_machines)
def deploy_machine(machines: MachineGroup) -> None:
def deploy_machine(machines: list[Machine]) -> None:
"""
Deploy to all hosts in parallel
"""
@@ -163,11 +163,9 @@ def deploy_machine(machines: MachineGroup) -> None:
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
)
ret = host.run(
switch_cmd,
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
extra_env=env,
)
if is_async_cancelled():
return
# if the machine is mobile, we retry to deploy with the mobile workaround method
is_mobile = machine.deployment.get("nixosMobileWorkaround", False)
@@ -187,62 +185,72 @@ def deploy_machine(machines: MachineGroup) -> None:
extra_env=env,
)
if len(machines.group.hosts) > 1:
machines.run_function(deploy)
else:
deploy(machines.machines[0])
with AsyncRuntime() as runtime:
for machine in machines:
machine.info(f"Updating {machine.name}")
runtime.async_run(
AsyncOpts(
tid=machine.name, async_ctx=AsyncContext(prefix=machine.name)
),
deploy,
machine,
)
runtime.join_all()
def update(args: argparse.Namespace) -> None:
if args.flake is None:
msg = "Could not find clan flake toplevel directory"
raise ClanError(msg)
machines = []
if len(args.machines) == 1 and args.target_host is not None:
machine = Machine(
name=args.machines[0], flake=args.flake, nix_options=args.option
)
machine.override_target_host = args.target_host
machine.override_build_host = args.build_host
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
machines.append(machine)
elif args.target_host is not None:
print("target host can only be specified for a single machine")
exit(1)
else:
if len(args.machines) == 0:
ignored_machines = []
for machine in get_all_machines(args.flake, args.option):
if machine.deployment.get("requireExplicitUpdate", False):
continue
try:
machine.build_host # noqa: B018
except ClanError: # check if we have a build host set
ignored_machines.append(machine)
continue
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
machine.override_build_host = args.build_host
machines.append(machine)
if not machines and ignored_machines != []:
print(
"WARNING: No machines to update."
"The following defined machines were ignored because they"
"do not have the `clan.core.networking.targetHost` nixos option set:",
file=sys.stderr,
)
for machine in ignored_machines:
print(machine, file=sys.stderr)
def update_command(args: argparse.Namespace) -> None:
try:
if args.flake is None:
msg = "Could not find clan flake toplevel directory"
raise ClanError(msg)
machines = []
if len(args.machines) == 1 and args.target_host is not None:
machine = Machine(
name=args.machines[0], flake=args.flake, nix_options=args.option
)
machine.override_target_host = args.target_host
machine.override_build_host = args.build_host
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
machines.append(machine)
elif args.target_host is not None:
print("target host can only be specified for a single machine")
exit(1)
else:
machines = get_selected_machines(args.flake, args.option, args.machines)
for machine in machines:
machine.override_build_host = args.build_host
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
if len(args.machines) == 0:
ignored_machines = []
for machine in get_all_machines(args.flake, args.option):
if machine.deployment.get("requireExplicitUpdate", False):
continue
try:
machine.build_host # noqa: B018
except ClanError: # check if we have a build host set
ignored_machines.append(machine)
continue
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
machine.override_build_host = args.build_host
machines.append(machine)
host_group = MachineGroup(machines)
deploy_machine(host_group)
if not machines and ignored_machines != []:
print(
"WARNING: No machines to update."
"The following defined machines were ignored because they"
"do not have the `clan.core.networking.targetHost` nixos option set:",
file=sys.stderr,
)
for machine in ignored_machines:
print(machine, file=sys.stderr)
else:
machines = get_selected_machines(args.flake, args.option, args.machines)
for machine in machines:
machine.override_build_host = args.build_host
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
deploy_machine(machines)
except KeyboardInterrupt:
log.warning("Interrupted by user")
sys.exit(1)
def register_update_parser(parser: argparse.ArgumentParser) -> None:
@@ -272,4 +280,4 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None:
type=str,
help="Address of the machine to build the flake, in the format of user@host:1234.",
)
parser.set_defaults(func=update)
parser.set_defaults(func=update_command)

View File

@@ -49,7 +49,7 @@ def nix_add_to_gcroots(nix_path: Path, dest: Path) -> None:
def nix_config() -> dict[str, Any]:
cmd = nix_command(["show-config", "--json"])
cmd = nix_command(["config", "show", "--json"])
proc = run_no_stdout(cmd)
data = json.loads(proc.stdout)
config = {}

View File

@@ -1,214 +0,0 @@
import logging
from collections.abc import Callable
from dataclasses import dataclass
from threading import Thread
from typing import Any
from clan_cli.cmd import RunOpts
from clan_cli.errors import ClanError
from clan_cli.ssh import T
from clan_cli.ssh.host import Host
from clan_cli.ssh.results import HostResult, Results
cmdlog = logging.getLogger(__name__)
def _worker(
func: Callable[[Host], T],
host: Host,
results: list[HostResult[T]],
idx: int,
) -> None:
try:
results[idx] = HostResult(host, func(host))
except Exception as e:
results[idx] = HostResult(host, e)
@dataclass
class HostGroup:
hosts: list[Host]
def _run_local(
self,
*,
cmd: list[str],
opts: RunOpts,
extra_env: dict[str, str] | None,
host: Host,
results: Results,
verbose_ssh: bool = False,
tty: bool = False,
) -> None:
try:
proc = host.run_local(
cmd,
opts,
extra_env,
)
results.append(HostResult(host, proc))
except Exception as e:
results.append(HostResult(host, e))
def _run_remote(
self,
cmd: list[str],
opts: RunOpts,
host: Host,
results: Results,
extra_env: dict[str, str] | None = None,
verbose_ssh: bool = False,
tty: bool = False,
) -> None:
try:
proc = host.run(
cmd,
extra_env=extra_env,
verbose_ssh=verbose_ssh,
opts=opts,
tty=tty,
)
results.append(HostResult(host, proc))
except Exception as e:
results.append(HostResult(host, e))
def _reraise_errors(self, results: list[HostResult[Any]]) -> None:
errors = 0
for result in results:
e = result.error
if e:
cmdlog.error(
f"failed with: {e}",
extra={"command_prefix": result.host.command_prefix},
)
errors += 1
if errors > 0:
msg = f"{errors} hosts failed with an error. Check the logs above"
raise ClanError(msg) from e
def _run(
self,
cmd: list[str],
opts: RunOpts | None = None,
local: bool = False,
extra_env: dict[str, str] | None = None,
verbose_ssh: bool = False,
tty: bool = False,
) -> Results:
results: Results = []
threads = []
if opts is None:
opts = RunOpts()
for host in self.hosts:
if local:
thread = Thread(
target=self._run_local,
kwargs={
"cmd": cmd,
"opts": opts,
"host": host,
"results": results,
"extra_env": extra_env,
"verbose_ssh": verbose_ssh,
"tty": tty,
},
)
else:
thread = Thread(
target=self._run_remote,
kwargs={
"results": results,
"cmd": cmd,
"host": host,
"opts": opts,
"extra_env": extra_env,
"verbose_ssh": verbose_ssh,
"tty": tty,
},
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
if opts.check:
self._reraise_errors(results)
return results
def run(
self,
cmd: list[str],
opts: RunOpts | None = None,
*,
extra_env: dict[str, str] | None = None,
verbose_ssh: bool = False,
tty: bool = False,
) -> Results:
"""
Command to run on the remote host via ssh
@return a lists of tuples containing Host and the result of the command for this Host
"""
return self._run(
cmd,
opts,
extra_env=extra_env,
verbose_ssh=verbose_ssh,
tty=tty,
)
def run_local(
self,
cmd: list[str],
opts: RunOpts | None = None,
*,
extra_env: dict[str, str] | None = None,
) -> Results:
"""
Command to run locally for each host in the group in parallel
@return a lists of tuples containing Host and the result of the command for this Host
"""
return self._run(
cmd,
opts,
local=True,
extra_env=extra_env,
)
def run_function(
self, func: Callable[[Host], T], check: bool = True
) -> list[HostResult[T]]:
"""
Function to run for each host in the group in parallel
"""
threads = []
results: list[HostResult[T]] = [
HostResult(h, ClanError(f"No result set for thread {i}"))
for (i, h) in enumerate(self.hosts)
]
for i, host in enumerate(self.hosts):
thread = Thread(
target=_worker,
args=(func, host, results, i),
)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
if check:
self._reraise_errors(results)
return results
def filter(self, pred: Callable[[Host], bool]) -> "HostGroup":
"""Return a new Group with the results filtered by the predicate"""
return HostGroup(list(filter(pred, self.hosts)))

View File

@@ -488,7 +488,8 @@ def generate_vars(
raise ClanError(msg) from errors[0][1]
if not was_regenerated and len(machines) > 0:
log.info("All vars are already up to date")
for machine in machines:
machine.info("All vars are already up to date")
return was_regenerated

View File

@@ -12,7 +12,8 @@ pytest_plugins = [
"sshd",
"command",
"ports",
"host_group",
"hosts",
"runtime",
"fixtures_flakes",
"stdout",
"nix_config",

View File

@@ -1,25 +0,0 @@
import os
import pwd
import pytest
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup
from clan_cli.ssh.host_key import HostKeyCheck
from sshd import Sshd
@pytest.fixture
def host_group(sshd: Sshd) -> HostGroup:
login = pwd.getpwuid(os.getuid()).pw_name
group = HostGroup(
[
Host(
"127.0.0.1",
port=sshd.port,
user=login,
key=sshd.key,
host_key_check=HostKeyCheck.NONE,
)
]
)
return group

View File

@@ -0,0 +1,23 @@
import os
import pwd
import pytest
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_key import HostKeyCheck
from sshd import Sshd
@pytest.fixture
def hosts(sshd: Sshd) -> list[Host]:
login = pwd.getpwuid(os.getuid()).pw_name
group = [
Host(
"127.0.0.1",
port=sshd.port,
user=login,
key=sshd.key,
host_key_check=HostKeyCheck.NONE,
)
]
return group

View File

@@ -0,0 +1,7 @@
import pytest
from clan_cli.async_run import AsyncRuntime
@pytest.fixture
def runtime() -> AsyncRuntime:
return AsyncRuntime()

View File

@@ -120,6 +120,7 @@ def test_all_dataclasses() -> None:
excludes = [
"api/__init__.py",
"cmd.py", # We don't want the UI to have access to the cmd module anyway
"async_run.py", # We don't want the UI to have access to the async_run module anyway
]
cli_path = Path("clan_cli").resolve()

View File

@@ -7,7 +7,7 @@ from clan_cli.facts.secret_modules.password_store import SecretStore
from clan_cli.machines.facts import machine_get_fact
from clan_cli.machines.machines import Machine
from clan_cli.nix import nix_shell
from clan_cli.ssh.host_group import HostGroup
from clan_cli.ssh.host import Host
from fixtures_flakes import ClanFlake
from helpers import cli
@@ -17,7 +17,7 @@ def test_upload_secret(
monkeypatch: pytest.MonkeyPatch,
flake: ClanFlake,
temporary_home: Path,
host_group: HostGroup,
hosts: list[Host],
) -> None:
flake.clan_modules = [
"root-password",
@@ -27,7 +27,7 @@ def test_upload_secret(
config = flake.machines["vm1"]
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
config["clan"]["core"]["networking"]["zerotier"]["controller"]["enable"] = True
host = host_group.hosts[0]
host = hosts[0]
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
config["clan"]["core"]["networking"]["targetHost"] = addr
config["clan"]["user-password"]["user"] = "alice"

View File

@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING
import pytest
from clan_cli.ssh.host_group import HostGroup
from clan_cli.ssh.host import Host
from fixtures_flakes import FlakeForTest
from helpers import cli
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
def test_secrets_upload(
monkeypatch: pytest.MonkeyPatch,
test_flake_with_core: FlakeForTest,
host_group: HostGroup,
hosts: list[Host],
age_keys: list["KeyPair"],
) -> None:
monkeypatch.chdir(test_flake_with_core.path)
@@ -48,7 +48,7 @@ def test_secrets_upload(
)
flake = test_flake_with_core.path.joinpath("flake.nix")
host = host_group.hosts[0]
host = hosts[0]
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
new_text = flake.read_text().replace("__CLAN_TARGET_ADDRESS__", addr)

View File

@@ -1,70 +1,49 @@
from clan_cli.cmd import Log, RunOpts
from clan_cli.async_run import AsyncRuntime
from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup
hosts = HostGroup([Host("some_host")])
host = Host("some_host")
def test_run_environment() -> None:
p2 = hosts.run_local(
def test_run_environment(runtime: AsyncRuntime) -> None:
p2 = runtime.async_run(
None,
host.run_local,
["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert p2[0].result.stdout == "true\n"
p3 = hosts.run_local(
["env"], RunOpts(shell=True, log=Log.STDERR), extra_env={"env_var": "true"}
assert p2.wait().result.stdout == "true\n"
p3 = runtime.async_run(
None,
host.run_local,
["env"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert "env_var=true" in p3[0].result.stdout
assert "env_var=true" in p3.wait().result.stdout
def test_run_local() -> None:
hosts.run_local(["echo", "hello"])
def test_run_local(runtime: AsyncRuntime) -> None:
p1 = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
)
assert p1.wait().result.stdout == "hello\n"
def test_timeout() -> None:
try:
hosts.run_local(["sleep", "10"], RunOpts(timeout=0.01))
except Exception:
pass
else:
msg = "should have raised TimeoutExpired"
raise AssertionError(msg)
def test_timeout(runtime: AsyncRuntime) -> None:
p1 = runtime.async_run(None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01))
error = p1.wait().error
assert isinstance(error, ClanCmdTimeoutError)
def test_run_function() -> None:
def some_func(h: Host) -> bool:
par = h.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
return par.stdout == "hello\n"
res = hosts.run_function(some_func)
assert res[0].result
def test_run_exception(runtime: AsyncRuntime) -> None:
p1 = runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True))
assert p1.wait().error is not None
def test_run_exception() -> None:
try:
hosts.run_local(["exit 1"], RunOpts(shell=True))
except Exception:
pass
else:
msg = "should have raised Exception"
raise AssertionError(msg)
def test_run_function_exception() -> None:
def some_func(h: Host) -> None:
h.run_local(["exit 1"], RunOpts(shell=True))
try:
hosts.run_function(some_func)
except Exception:
pass
else:
msg = "should have raised Exception"
raise AssertionError(msg)
def test_run_local_non_shell() -> None:
p2 = hosts.run_local(["echo", "1"], RunOpts(log=Log.STDERR))
assert p2[0].result.stdout == "1\n"
def test_run_local_non_shell(runtime: AsyncRuntime) -> None:
p2 = runtime.async_run(None, host.run_local, ["echo", "1"], RunOpts(log=Log.STDERR))
assert p2.wait().result.stdout == "1\n"

View File

@@ -1,8 +1,8 @@
import pytest
from clan_cli.cmd import Log, RunOpts
from clan_cli.async_run import AsyncRuntime
from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts
from clan_cli.errors import ClanError, CmdOut
from clan_cli.ssh.host import Host
from clan_cli.ssh.host_group import HostGroup
from clan_cli.ssh.host_key import HostKeyCheck
from clan_cli.ssh.parse import parse_deployment_address
@@ -20,52 +20,75 @@ def test_parse_ipv6() -> None:
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
def test_run(host_group: HostGroup) -> None:
proc = host_group.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
assert proc[0].result.stdout == "hello\n"
def test_run(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
)
assert proc.wait().result.stdout == "hello\n"
def test_run_environment(host_group: HostGroup) -> None:
p1 = host_group.run(
["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert p1[0].result.stdout == "true\n"
p2 = host_group.run(["env"], RunOpts(log=Log.STDERR), extra_env={"env_var": "true"})
assert "env_var=true" in p2[0].result.stdout
def test_run_environment(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None,
host.run_local,
["echo $env_var"],
RunOpts(shell=True, log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert proc.wait().result.stdout == "true\n"
for host in hosts:
p2 = runtime.async_run(
None,
host.run_local,
["env"],
RunOpts(log=Log.STDERR),
extra_env={"env_var": "true"},
)
assert "env_var=true" in p2.wait().result.stdout
def test_run_no_shell(host_group: HostGroup) -> None:
proc = host_group.run(["echo", "$hello"], RunOpts(log=Log.STDERR))
assert proc[0].result.stdout == "$hello\n"
def test_run_no_shell(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
)
assert proc.wait().result.stdout == "hello\n"
def test_run_function(host_group: HostGroup) -> None:
def test_run_function(hosts: list[Host], runtime: AsyncRuntime) -> None:
def some_func(h: Host) -> bool:
p = h.run(["echo", "hello"])
return p.stdout == "hello\n"
res = host_group.run_function(some_func)
assert res[0].result
for host in hosts:
proc = runtime.async_run(None, some_func, host)
assert proc.wait().result
def test_timeout(host_group: HostGroup) -> None:
try:
host_group.run_local(["sleep", "10"], RunOpts(timeout=0.01))
except Exception:
pass
else:
msg = "should have raised TimeoutExpired"
raise AssertionError(msg)
def test_timeout(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01)
)
error = proc.wait().error
assert isinstance(error, ClanCmdTimeoutError)
def test_run_exception(host_group: HostGroup) -> None:
r = host_group.run(["exit 1"], RunOpts(check=False, shell=True))
assert r[0].result.returncode == 1
def test_run_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
for host in hosts:
proc = runtime.async_run(
None, host.run_local, ["exit 1"], RunOpts(shell=True, check=False)
)
assert proc.wait().result.returncode == 1
try:
host_group.run(["exit 1"], RunOpts(shell=True))
for host in hosts:
runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True))
runtime.join_all()
runtime.check_all()
except Exception:
pass
else:
@@ -73,12 +96,15 @@ def test_run_exception(host_group: HostGroup) -> None:
raise AssertionError(msg)
def test_run_function_exception(host_group: HostGroup) -> None:
def test_run_function_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
def some_func(h: Host) -> CmdOut:
return h.run_local(["exit 1"], RunOpts(shell=True))
try:
host_group.run_function(some_func)
for host in hosts:
runtime.async_run(None, some_func, host)
runtime.join_all()
runtime.check_all()
except Exception:
pass
else: