clan-cli: Replace HostGroup and MachineGroup with generic AsyncRuntime class. Propagate cmd prefix over thread local. Close threads on CTRL+C
This commit is contained in:
@@ -434,8 +434,8 @@ def main() -> None:
|
|||||||
else:
|
else:
|
||||||
log.error("%s", e) # noqa: TRY400
|
log.error("%s", e) # noqa: TRY400
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt as ex:
|
||||||
log.warning("Interrupted by user")
|
log.warning("Interrupted by user", exc_info=ex)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
312
pkgs/clan-cli/clan_cli/async_run.py
Normal file
312
pkgs/clan-cli/clan_cli/async_run.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import IO, Any, Generic, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from clan_cli.errors import ClanError
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Why did we create a custom AsyncRuntime instead of using asyncio?
|
||||||
|
#
|
||||||
|
# The AsyncRuntime class allows us to run functions in separate threads for asynchronous
|
||||||
|
# execution without requiring the use of the "async" keyword. By using this approach,
|
||||||
|
# functions can gracefully handle cancellation by checking if get_async_ctx().cancel is True.
|
||||||
|
#
|
||||||
|
# There was some resistance to using asyncio, partly due to challenges we faced when
|
||||||
|
# implementing it in our first web interface. Threads felt simpler and more familiar
|
||||||
|
# for our use case.
|
||||||
|
#
|
||||||
|
# That said, asyncio is generally more efficient because it uses non-blocking I/O
|
||||||
|
# and avoids issues with Python's Global Interpreter Lock (GIL), which can limit the
|
||||||
|
# performance of threads in CPU-heavy workloads.
|
||||||
|
#
|
||||||
|
# Using threads works well for us because most of the time is spent waiting for commands
|
||||||
|
# or external processes to finish, rather than performing heavy computing tasks.
|
||||||
|
#
|
||||||
|
# Note: Starting with Python 3.14, the GIL can be disabled to enable true parallelism.
|
||||||
|
# However, disabling the GIL introduces a 10-40% performance cost for Python code
|
||||||
|
# due to the overhead of additional locking.
|
||||||
|
|
||||||
|
# Define generics for return type and call signature
|
||||||
|
R = TypeVar("R") # Return type of the callable
|
||||||
|
P = ParamSpec("P") # Parameters of the callable
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AsyncResult(Generic[R]):
|
||||||
|
_result: R | Exception
|
||||||
|
|
||||||
|
@property
|
||||||
|
def error(self) -> Exception | None:
|
||||||
|
"""
|
||||||
|
Returns an error if the callable raised an exception.
|
||||||
|
"""
|
||||||
|
if isinstance(self._result, Exception):
|
||||||
|
return self._result
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self) -> R:
|
||||||
|
"""
|
||||||
|
Unwraps and returns the result if no exception occurred.
|
||||||
|
Raises the exception otherwise.
|
||||||
|
"""
|
||||||
|
if isinstance(self._result, Exception):
|
||||||
|
raise self._result
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AsyncContext:
|
||||||
|
"""
|
||||||
|
This class stores thread-local data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prefix: str | None = None # prefix for logging
|
||||||
|
stdout: IO[bytes] | None = None # stdout of subprocesses
|
||||||
|
stderr: IO[bytes] | None = None # stderr of subprocesses
|
||||||
|
cancel: bool = False # Used to signal cancellation of task
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AsyncOpts:
|
||||||
|
"""
|
||||||
|
Options for the async_run function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tid: str | None = None
|
||||||
|
check: bool = True
|
||||||
|
async_ctx: AsyncContext = field(default_factory=AsyncContext)
|
||||||
|
|
||||||
|
|
||||||
|
ASYNC_CTX_THREAD_LOCAL = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def is_async_cancelled() -> bool:
|
||||||
|
"""
|
||||||
|
Check if the current task has been cancelled.
|
||||||
|
"""
|
||||||
|
return get_async_ctx().cancel
|
||||||
|
|
||||||
|
|
||||||
|
def get_async_ctx() -> AsyncContext:
|
||||||
|
"""
|
||||||
|
Retrieve the current AsyncContext, creating a new one if none exists.
|
||||||
|
"""
|
||||||
|
global ASYNC_CTX_THREAD_LOCAL
|
||||||
|
|
||||||
|
if not hasattr(ASYNC_CTX_THREAD_LOCAL, "async_ctx"):
|
||||||
|
ASYNC_CTX_THREAD_LOCAL.async_ctx = AsyncContext()
|
||||||
|
return ASYNC_CTX_THREAD_LOCAL.async_ctx
|
||||||
|
|
||||||
|
|
||||||
|
def set_async_ctx(ctx: AsyncContext) -> None:
|
||||||
|
global ASYNC_CTX_THREAD_LOCAL
|
||||||
|
ASYNC_CTX_THREAD_LOCAL.async_ctx = ctx
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncThread(threading.Thread, Generic[P, R]):
|
||||||
|
function: Callable[P, R]
|
||||||
|
args: Any
|
||||||
|
kwargs: Any
|
||||||
|
result: AsyncResult[R] | None
|
||||||
|
finished: bool
|
||||||
|
condition: threading.Condition
|
||||||
|
async_opts: AsyncOpts
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
async_opts: AsyncOpts,
|
||||||
|
condition: threading.Condition,
|
||||||
|
function: Callable[P, R],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
A threaded wrapper for running a function asynchronously.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.function = function
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.result: AsyncResult[R] | None = None # Store the result or exception
|
||||||
|
self.finished = False # Set to True after the thread finishes execution
|
||||||
|
self.condition = condition # Shared condition variable
|
||||||
|
self.async_opts = async_opts
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""
|
||||||
|
Run the function in a separate thread.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
set_async_ctx(self.async_opts.async_ctx)
|
||||||
|
self.result = AsyncResult(_result=self.function(*self.args, **self.kwargs))
|
||||||
|
except Exception as ex:
|
||||||
|
self.result = AsyncResult(_result=ex)
|
||||||
|
finally:
|
||||||
|
self.finished = True
|
||||||
|
# Acquire the condition lock before notifying
|
||||||
|
with self.condition:
|
||||||
|
self.condition.notify_all() # Notify waiting threads that this thread is done
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AsyncFuture(Generic[R]):
|
||||||
|
_tid: str
|
||||||
|
_runtime: "AsyncRuntime"
|
||||||
|
|
||||||
|
def wait(self) -> AsyncResult[R]:
|
||||||
|
"""
|
||||||
|
Wait for the task to finish.
|
||||||
|
"""
|
||||||
|
if self._tid not in self._runtime.tasks:
|
||||||
|
msg = f"No task with the name '{self._tid}' exists."
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
thread = self._runtime.tasks[self._tid]
|
||||||
|
thread.join()
|
||||||
|
result = self.get_result()
|
||||||
|
if result is None:
|
||||||
|
msg = f"Task '{self._tid}' unexpectedly returned None."
|
||||||
|
raise ClanError(msg)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_result(self) -> AsyncResult[R] | None:
|
||||||
|
"""
|
||||||
|
Retrieve the result of a finished task and remove it from the task list.
|
||||||
|
"""
|
||||||
|
if self._tid not in self._runtime.tasks:
|
||||||
|
msg = f"No task with the name '{self._tid}' exists."
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
thread = self._runtime.tasks[self._tid]
|
||||||
|
|
||||||
|
if not thread.finished:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Remove the task after retrieving the result
|
||||||
|
result = thread.result
|
||||||
|
del self._runtime.tasks[self._tid]
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
msg = f"The result for task '{self._tid}' is unexpectedly None."
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AsyncRuntime:
|
||||||
|
tasks: dict[str, AsyncThread[Any, Any]] = field(default_factory=dict)
|
||||||
|
condition: threading.Condition = field(default_factory=threading.Condition)
|
||||||
|
|
||||||
|
def async_run(
|
||||||
|
self,
|
||||||
|
opts: AsyncOpts | None,
|
||||||
|
function: Callable[P, R],
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> AsyncFuture[R]:
|
||||||
|
"""
|
||||||
|
Run the given function asynchronously in a thread with a specific name and arguments.
|
||||||
|
The function's static typing is preserved.
|
||||||
|
"""
|
||||||
|
if opts is None:
|
||||||
|
opts = AsyncOpts()
|
||||||
|
|
||||||
|
if opts.tid is None:
|
||||||
|
opts.tid = uuid.uuid4().hex
|
||||||
|
|
||||||
|
if opts.tid in self.tasks:
|
||||||
|
msg = f"A task with the name '{opts.tid}' is already running."
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
# Create and start the new AsyncThread
|
||||||
|
thread = AsyncThread(opts, self.condition, function, *args, **kwargs)
|
||||||
|
self.tasks[opts.tid] = thread
|
||||||
|
thread.start()
|
||||||
|
return AsyncFuture(opts.tid, self)
|
||||||
|
|
||||||
|
def join_all(self) -> None:
|
||||||
|
"""
|
||||||
|
Wait for all tasks to finish
|
||||||
|
"""
|
||||||
|
with self.condition:
|
||||||
|
while any(
|
||||||
|
not task.finished for task in self.tasks.values()
|
||||||
|
): # Check if any tasks are still running
|
||||||
|
self.condition.wait() # Wait until a thread signals completion
|
||||||
|
|
||||||
|
def check_all(self) -> None:
|
||||||
|
"""
|
||||||
|
Check if there where any errors
|
||||||
|
"""
|
||||||
|
err_count = 0
|
||||||
|
|
||||||
|
for name, task in self.tasks.items():
|
||||||
|
if task.finished and task.async_opts.check:
|
||||||
|
assert task.result is not None
|
||||||
|
error = task.result.error
|
||||||
|
if log.isEnabledFor(logging.DEBUG):
|
||||||
|
log.error(
|
||||||
|
f"failed with error: {error}",
|
||||||
|
extra={"command_prefix": name},
|
||||||
|
exc_info=error,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log.error(
|
||||||
|
f"failed with error: {error}", extra={"command_prefix": name}
|
||||||
|
)
|
||||||
|
err_count += 1
|
||||||
|
|
||||||
|
if err_count > 0:
|
||||||
|
msg = f"{err_count} hosts failed with an error. Check the logs above"
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
def __enter__(self) -> "AsyncRuntime":
|
||||||
|
"""
|
||||||
|
Enter the runtime context related to this object.
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_value: BaseException | None,
|
||||||
|
traceback: types.TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Exit the runtime context related to this object.
|
||||||
|
Sets async_ctx.cancel to True to signal cancellation.
|
||||||
|
"""
|
||||||
|
for name, task in self.tasks.items():
|
||||||
|
if not task.finished:
|
||||||
|
task.async_opts.async_ctx.cancel = True
|
||||||
|
log.debug(f"Canceling task {name}")
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
if __name__ == "__main__":
|
||||||
|
runtime = AsyncRuntime()
|
||||||
|
|
||||||
|
def add(a: int, b: int) -> int:
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
def concatenate(a: str, b: str) -> str:
|
||||||
|
time.sleep(1)
|
||||||
|
msg = "Hello World"
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
with runtime:
|
||||||
|
p1 = runtime.async_run(None, add, 1, 2)
|
||||||
|
p2 = runtime.async_run(None, concatenate, "Hello ", "World")
|
||||||
|
|
||||||
|
add_result = p1.wait()
|
||||||
|
print(add_result.result) # Output: 3
|
||||||
|
concat_result = p2.wait()
|
||||||
|
print(concat_result.error) # Output: Hello World
|
||||||
@@ -17,6 +17,7 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import IO, Any
|
from typing import IO, Any
|
||||||
|
|
||||||
|
from clan_cli.async_run import get_async_ctx, is_async_cancelled
|
||||||
from clan_cli.colors import Color
|
from clan_cli.colors import Color
|
||||||
from clan_cli.custom_logger import print_trace
|
from clan_cli.custom_logger import print_trace
|
||||||
from clan_cli.errors import ClanCmdError, ClanError, CmdOut, indent_command
|
from clan_cli.errors import ClanCmdError, ClanError, CmdOut, indent_command
|
||||||
@@ -87,7 +88,8 @@ def handle_io(
|
|||||||
stdout_extra = {}
|
stdout_extra = {}
|
||||||
stderr_extra = {}
|
stderr_extra = {}
|
||||||
if prefix:
|
if prefix:
|
||||||
stdout_extra["command_prefix"] = stderr_extra["command_prefix"] = prefix
|
stdout_extra["command_prefix"] = prefix
|
||||||
|
stderr_extra["command_prefix"] = prefix
|
||||||
if msg_color and msg_color.stderr:
|
if msg_color and msg_color.stderr:
|
||||||
stdout_extra["color"] = msg_color.stderr.value
|
stdout_extra["color"] = msg_color.stderr.value
|
||||||
if msg_color and msg_color.stdout:
|
if msg_color and msg_color.stdout:
|
||||||
@@ -101,6 +103,13 @@ def handle_io(
|
|||||||
description = prefix
|
description = prefix
|
||||||
raise ClanCmdTimeoutError(msg=msg, description=description, timeout=timeout)
|
raise ClanCmdTimeoutError(msg=msg, description=description, timeout=timeout)
|
||||||
|
|
||||||
|
# Check if the command has been cancelled
|
||||||
|
if is_async_cancelled():
|
||||||
|
cmdlog.warning("Command cancelled", extra=stderr_extra)
|
||||||
|
|
||||||
|
# Terminate process
|
||||||
|
break
|
||||||
|
|
||||||
# Wait for data to be available
|
# Wait for data to be available
|
||||||
readlist, writelist, _ = select.select(rlist, wlist, [], 0.1)
|
readlist, writelist, _ = select.select(rlist, wlist, [], 0.1)
|
||||||
if len(readlist) == 0 and len(writelist) == 0:
|
if len(readlist) == 0 and len(writelist) == 0:
|
||||||
@@ -116,7 +125,7 @@ def handle_io(
|
|||||||
|
|
||||||
# If Log.STDOUT is set, log the stdout output
|
# If Log.STDOUT is set, log the stdout output
|
||||||
if ret and log in [Log.STDOUT, Log.BOTH]:
|
if ret and log in [Log.STDOUT, Log.BOTH]:
|
||||||
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
|
lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n")
|
||||||
for line in lines:
|
for line in lines:
|
||||||
cmdlog.info(line, extra=stdout_extra)
|
cmdlog.info(line, extra=stdout_extra)
|
||||||
|
|
||||||
@@ -133,7 +142,7 @@ def handle_io(
|
|||||||
|
|
||||||
# If Log.STDERR is set, log the stderr output
|
# If Log.STDERR is set, log the stderr output
|
||||||
if ret and log in [Log.STDERR, Log.BOTH]:
|
if ret and log in [Log.STDERR, Log.BOTH]:
|
||||||
lines = ret.decode("utf-8", "replace").rstrip("\n").split("\n")
|
lines = ret.decode("utf-8", "replace").rstrip("\n").rstrip().split("\n")
|
||||||
for line in lines:
|
for line in lines:
|
||||||
cmdlog.info(line, extra=stderr_extra)
|
cmdlog.info(line, extra=stderr_extra)
|
||||||
|
|
||||||
@@ -273,6 +282,17 @@ def run(
|
|||||||
if options.cwd is None:
|
if options.cwd is None:
|
||||||
options.cwd = Path.cwd()
|
options.cwd = Path.cwd()
|
||||||
|
|
||||||
|
async_ctx = get_async_ctx()
|
||||||
|
# Fill in the options from the thread-local data
|
||||||
|
# if they are not set in the options
|
||||||
|
if async_ctx:
|
||||||
|
if options.prefix is None:
|
||||||
|
options.prefix = async_ctx.prefix
|
||||||
|
if options.stdout is None:
|
||||||
|
options.stdout = async_ctx.stdout
|
||||||
|
if options.stderr is None:
|
||||||
|
options.stderr = async_ctx.stderr
|
||||||
|
|
||||||
if options.input:
|
if options.input:
|
||||||
if any(not ch.isprintable() for ch in options.input.decode("ascii", "replace")):
|
if any(not ch.isprintable() for ch in options.input.decode("ascii", "replace")):
|
||||||
filtered_input = "<<binary_blob>>"
|
filtered_input = "<<binary_blob>>"
|
||||||
@@ -318,7 +338,8 @@ def run(
|
|||||||
stdout=options.stdout,
|
stdout=options.stdout,
|
||||||
stderr=options.stderr,
|
stderr=options.stderr,
|
||||||
)
|
)
|
||||||
process.wait()
|
if not is_async_cancelled():
|
||||||
|
process.wait()
|
||||||
|
|
||||||
global TIME_TABLE
|
global TIME_TABLE
|
||||||
if TIME_TABLE:
|
if TIME_TABLE:
|
||||||
@@ -336,7 +357,9 @@ def run(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if options.check and process.returncode != 0:
|
if options.check and process.returncode != 0:
|
||||||
raise ClanCmdError(cmd_out)
|
err = ClanCmdError(cmd_out)
|
||||||
|
err.msg = "Command has been cancelled"
|
||||||
|
raise err
|
||||||
|
|
||||||
return cmd_out
|
return cmd_out
|
||||||
|
|
||||||
@@ -359,3 +382,6 @@ def run_no_stdout(
|
|||||||
cmd,
|
cmd,
|
||||||
opts,
|
opts,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# type: ignore
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ class PrefixFormatter(logging.Formatter):
|
|||||||
|
|
||||||
# If command_prefix is set, color the prefix with a unique color.
|
# If command_prefix is set, color the prefix with a unique color.
|
||||||
elif command_prefix:
|
elif command_prefix:
|
||||||
|
command_prefix = command_prefix[
|
||||||
|
:10
|
||||||
|
] # Truncate the command prefix to 10 characters.
|
||||||
prefix_color = self.hostname_colorcode(command_prefix)
|
prefix_color = self.hostname_colorcode(command_prefix)
|
||||||
format_str = color_by_tuple(f"[{command_prefix}]", fg=prefix_color)
|
format_str = color_by_tuple(f"[{command_prefix}]", fg=prefix_color)
|
||||||
format_str += color_by_tuple(" %(message)s", fg=msg_color)
|
format_str += color_by_tuple(" %(message)s", fg=msg_color)
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
from clan_cli.ssh.host import Host
|
|
||||||
from clan_cli.ssh.host_group import HostGroup, HostResult
|
|
||||||
|
|
||||||
from .machines import Machine
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class MachineGroup:
|
|
||||||
def __init__(self, machines: list[Machine]) -> None:
|
|
||||||
self.machines = machines
|
|
||||||
self.group = HostGroup([m.target_host for m in machines])
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return str(self)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f"MachineGroup({self.group})"
|
|
||||||
|
|
||||||
def run_function(
|
|
||||||
self, func: Callable[[Machine], T], check: bool = True
|
|
||||||
) -> list[HostResult[T]]:
|
|
||||||
"""
|
|
||||||
Function to run for each host in the group in parallel
|
|
||||||
|
|
||||||
@func the function to call
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrapped_func(host: Host) -> T:
|
|
||||||
return func(host.meta["machine"])
|
|
||||||
|
|
||||||
return self.group.run_function(wrapped_func, check=check)
|
|
||||||
@@ -6,6 +6,7 @@ import shlex
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from clan_cli.api import API
|
from clan_cli.api import API
|
||||||
|
from clan_cli.async_run import AsyncContext, AsyncOpts, AsyncRuntime, is_async_cancelled
|
||||||
from clan_cli.clan_uri import FlakeId
|
from clan_cli.clan_uri import FlakeId
|
||||||
from clan_cli.cmd import MsgColor, RunOpts, run
|
from clan_cli.cmd import MsgColor, RunOpts, run
|
||||||
from clan_cli.colors import AnsiColor
|
from clan_cli.colors import AnsiColor
|
||||||
@@ -24,7 +25,6 @@ from clan_cli.vars.generate import generate_vars
|
|||||||
from clan_cli.vars.upload import upload_secret_vars
|
from clan_cli.vars.upload import upload_secret_vars
|
||||||
|
|
||||||
from .inventory import get_all_machines, get_selected_machines
|
from .inventory import get_all_machines, get_selected_machines
|
||||||
from .machine_group import MachineGroup
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -113,10 +113,10 @@ def update_machines(base_path: str, machines: list[InventoryMachine]) -> None:
|
|||||||
# m.override_build_host = machine.deploy.buildHost
|
# m.override_build_host = machine.deploy.buildHost
|
||||||
group_machines.append(m)
|
group_machines.append(m)
|
||||||
|
|
||||||
deploy_machine(MachineGroup(group_machines))
|
deploy_machine(group_machines)
|
||||||
|
|
||||||
|
|
||||||
def deploy_machine(machines: MachineGroup) -> None:
|
def deploy_machine(machines: list[Machine]) -> None:
|
||||||
"""
|
"""
|
||||||
Deploy to all hosts in parallel
|
Deploy to all hosts in parallel
|
||||||
"""
|
"""
|
||||||
@@ -163,11 +163,9 @@ def deploy_machine(machines: MachineGroup) -> None:
|
|||||||
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
|
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
|
||||||
extra_env=env,
|
extra_env=env,
|
||||||
)
|
)
|
||||||
ret = host.run(
|
|
||||||
switch_cmd,
|
if is_async_cancelled():
|
||||||
RunOpts(check=False, msg_color=MsgColor(stderr=AnsiColor.DEFAULT)),
|
return
|
||||||
extra_env=env,
|
|
||||||
)
|
|
||||||
|
|
||||||
# if the machine is mobile, we retry to deploy with the mobile workaround method
|
# if the machine is mobile, we retry to deploy with the mobile workaround method
|
||||||
is_mobile = machine.deployment.get("nixosMobileWorkaround", False)
|
is_mobile = machine.deployment.get("nixosMobileWorkaround", False)
|
||||||
@@ -187,62 +185,72 @@ def deploy_machine(machines: MachineGroup) -> None:
|
|||||||
extra_env=env,
|
extra_env=env,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(machines.group.hosts) > 1:
|
with AsyncRuntime() as runtime:
|
||||||
machines.run_function(deploy)
|
for machine in machines:
|
||||||
else:
|
machine.info(f"Updating {machine.name}")
|
||||||
deploy(machines.machines[0])
|
runtime.async_run(
|
||||||
|
AsyncOpts(
|
||||||
|
tid=machine.name, async_ctx=AsyncContext(prefix=machine.name)
|
||||||
|
),
|
||||||
|
deploy,
|
||||||
|
machine,
|
||||||
|
)
|
||||||
|
runtime.join_all()
|
||||||
|
|
||||||
|
|
||||||
def update(args: argparse.Namespace) -> None:
|
def update_command(args: argparse.Namespace) -> None:
|
||||||
if args.flake is None:
|
try:
|
||||||
msg = "Could not find clan flake toplevel directory"
|
if args.flake is None:
|
||||||
raise ClanError(msg)
|
msg = "Could not find clan flake toplevel directory"
|
||||||
machines = []
|
raise ClanError(msg)
|
||||||
if len(args.machines) == 1 and args.target_host is not None:
|
machines = []
|
||||||
machine = Machine(
|
if len(args.machines) == 1 and args.target_host is not None:
|
||||||
name=args.machines[0], flake=args.flake, nix_options=args.option
|
machine = Machine(
|
||||||
)
|
name=args.machines[0], flake=args.flake, nix_options=args.option
|
||||||
machine.override_target_host = args.target_host
|
)
|
||||||
machine.override_build_host = args.build_host
|
machine.override_target_host = args.target_host
|
||||||
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
machine.override_build_host = args.build_host
|
||||||
machines.append(machine)
|
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
||||||
|
machines.append(machine)
|
||||||
elif args.target_host is not None:
|
|
||||||
print("target host can only be specified for a single machine")
|
|
||||||
exit(1)
|
|
||||||
else:
|
|
||||||
if len(args.machines) == 0:
|
|
||||||
ignored_machines = []
|
|
||||||
for machine in get_all_machines(args.flake, args.option):
|
|
||||||
if machine.deployment.get("requireExplicitUpdate", False):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
machine.build_host # noqa: B018
|
|
||||||
except ClanError: # check if we have a build host set
|
|
||||||
ignored_machines.append(machine)
|
|
||||||
continue
|
|
||||||
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
|
||||||
machine.override_build_host = args.build_host
|
|
||||||
machines.append(machine)
|
|
||||||
|
|
||||||
if not machines and ignored_machines != []:
|
|
||||||
print(
|
|
||||||
"WARNING: No machines to update."
|
|
||||||
"The following defined machines were ignored because they"
|
|
||||||
"do not have the `clan.core.networking.targetHost` nixos option set:",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
for machine in ignored_machines:
|
|
||||||
print(machine, file=sys.stderr)
|
|
||||||
|
|
||||||
|
elif args.target_host is not None:
|
||||||
|
print("target host can only be specified for a single machine")
|
||||||
|
exit(1)
|
||||||
else:
|
else:
|
||||||
machines = get_selected_machines(args.flake, args.option, args.machines)
|
if len(args.machines) == 0:
|
||||||
for machine in machines:
|
ignored_machines = []
|
||||||
machine.override_build_host = args.build_host
|
for machine in get_all_machines(args.flake, args.option):
|
||||||
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
if machine.deployment.get("requireExplicitUpdate", False):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
machine.build_host # noqa: B018
|
||||||
|
except ClanError: # check if we have a build host set
|
||||||
|
ignored_machines.append(machine)
|
||||||
|
continue
|
||||||
|
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
||||||
|
machine.override_build_host = args.build_host
|
||||||
|
machines.append(machine)
|
||||||
|
|
||||||
host_group = MachineGroup(machines)
|
if not machines and ignored_machines != []:
|
||||||
deploy_machine(host_group)
|
print(
|
||||||
|
"WARNING: No machines to update."
|
||||||
|
"The following defined machines were ignored because they"
|
||||||
|
"do not have the `clan.core.networking.targetHost` nixos option set:",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
for machine in ignored_machines:
|
||||||
|
print(machine, file=sys.stderr)
|
||||||
|
|
||||||
|
else:
|
||||||
|
machines = get_selected_machines(args.flake, args.option, args.machines)
|
||||||
|
for machine in machines:
|
||||||
|
machine.override_build_host = args.build_host
|
||||||
|
machine.host_key_check = HostKeyCheck.from_str(args.host_key_check)
|
||||||
|
|
||||||
|
deploy_machine(machines)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
log.warning("Interrupted by user")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def register_update_parser(parser: argparse.ArgumentParser) -> None:
|
def register_update_parser(parser: argparse.ArgumentParser) -> None:
|
||||||
@@ -272,4 +280,4 @@ def register_update_parser(parser: argparse.ArgumentParser) -> None:
|
|||||||
type=str,
|
type=str,
|
||||||
help="Address of the machine to build the flake, in the format of user@host:1234.",
|
help="Address of the machine to build the flake, in the format of user@host:1234.",
|
||||||
)
|
)
|
||||||
parser.set_defaults(func=update)
|
parser.set_defaults(func=update_command)
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ def nix_add_to_gcroots(nix_path: Path, dest: Path) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def nix_config() -> dict[str, Any]:
|
def nix_config() -> dict[str, Any]:
|
||||||
cmd = nix_command(["show-config", "--json"])
|
cmd = nix_command(["config", "show", "--json"])
|
||||||
proc = run_no_stdout(cmd)
|
proc = run_no_stdout(cmd)
|
||||||
data = json.loads(proc.stdout)
|
data = json.loads(proc.stdout)
|
||||||
config = {}
|
config = {}
|
||||||
|
|||||||
@@ -1,214 +0,0 @@
|
|||||||
import logging
|
|
||||||
from collections.abc import Callable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from threading import Thread
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from clan_cli.cmd import RunOpts
|
|
||||||
from clan_cli.errors import ClanError
|
|
||||||
from clan_cli.ssh import T
|
|
||||||
from clan_cli.ssh.host import Host
|
|
||||||
from clan_cli.ssh.results import HostResult, Results
|
|
||||||
|
|
||||||
cmdlog = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _worker(
|
|
||||||
func: Callable[[Host], T],
|
|
||||||
host: Host,
|
|
||||||
results: list[HostResult[T]],
|
|
||||||
idx: int,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
results[idx] = HostResult(host, func(host))
|
|
||||||
except Exception as e:
|
|
||||||
results[idx] = HostResult(host, e)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class HostGroup:
|
|
||||||
hosts: list[Host]
|
|
||||||
|
|
||||||
def _run_local(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
cmd: list[str],
|
|
||||||
opts: RunOpts,
|
|
||||||
extra_env: dict[str, str] | None,
|
|
||||||
host: Host,
|
|
||||||
results: Results,
|
|
||||||
verbose_ssh: bool = False,
|
|
||||||
tty: bool = False,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
proc = host.run_local(
|
|
||||||
cmd,
|
|
||||||
opts,
|
|
||||||
extra_env,
|
|
||||||
)
|
|
||||||
results.append(HostResult(host, proc))
|
|
||||||
except Exception as e:
|
|
||||||
results.append(HostResult(host, e))
|
|
||||||
|
|
||||||
def _run_remote(
|
|
||||||
self,
|
|
||||||
cmd: list[str],
|
|
||||||
opts: RunOpts,
|
|
||||||
host: Host,
|
|
||||||
results: Results,
|
|
||||||
extra_env: dict[str, str] | None = None,
|
|
||||||
verbose_ssh: bool = False,
|
|
||||||
tty: bool = False,
|
|
||||||
) -> None:
|
|
||||||
try:
|
|
||||||
proc = host.run(
|
|
||||||
cmd,
|
|
||||||
extra_env=extra_env,
|
|
||||||
verbose_ssh=verbose_ssh,
|
|
||||||
opts=opts,
|
|
||||||
tty=tty,
|
|
||||||
)
|
|
||||||
results.append(HostResult(host, proc))
|
|
||||||
except Exception as e:
|
|
||||||
results.append(HostResult(host, e))
|
|
||||||
|
|
||||||
def _reraise_errors(self, results: list[HostResult[Any]]) -> None:
|
|
||||||
errors = 0
|
|
||||||
for result in results:
|
|
||||||
e = result.error
|
|
||||||
if e:
|
|
||||||
cmdlog.error(
|
|
||||||
f"failed with: {e}",
|
|
||||||
extra={"command_prefix": result.host.command_prefix},
|
|
||||||
)
|
|
||||||
errors += 1
|
|
||||||
if errors > 0:
|
|
||||||
msg = f"{errors} hosts failed with an error. Check the logs above"
|
|
||||||
raise ClanError(msg) from e
|
|
||||||
|
|
||||||
def _run(
|
|
||||||
self,
|
|
||||||
cmd: list[str],
|
|
||||||
opts: RunOpts | None = None,
|
|
||||||
local: bool = False,
|
|
||||||
extra_env: dict[str, str] | None = None,
|
|
||||||
verbose_ssh: bool = False,
|
|
||||||
tty: bool = False,
|
|
||||||
) -> Results:
|
|
||||||
results: Results = []
|
|
||||||
threads = []
|
|
||||||
|
|
||||||
if opts is None:
|
|
||||||
opts = RunOpts()
|
|
||||||
|
|
||||||
for host in self.hosts:
|
|
||||||
if local:
|
|
||||||
thread = Thread(
|
|
||||||
target=self._run_local,
|
|
||||||
kwargs={
|
|
||||||
"cmd": cmd,
|
|
||||||
"opts": opts,
|
|
||||||
"host": host,
|
|
||||||
"results": results,
|
|
||||||
"extra_env": extra_env,
|
|
||||||
"verbose_ssh": verbose_ssh,
|
|
||||||
"tty": tty,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
thread = Thread(
|
|
||||||
target=self._run_remote,
|
|
||||||
kwargs={
|
|
||||||
"results": results,
|
|
||||||
"cmd": cmd,
|
|
||||||
"host": host,
|
|
||||||
"opts": opts,
|
|
||||||
"extra_env": extra_env,
|
|
||||||
"verbose_ssh": verbose_ssh,
|
|
||||||
"tty": tty,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
thread.start()
|
|
||||||
threads.append(thread)
|
|
||||||
|
|
||||||
for thread in threads:
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
if opts.check:
|
|
||||||
self._reraise_errors(results)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
cmd: list[str],
|
|
||||||
opts: RunOpts | None = None,
|
|
||||||
*,
|
|
||||||
extra_env: dict[str, str] | None = None,
|
|
||||||
verbose_ssh: bool = False,
|
|
||||||
tty: bool = False,
|
|
||||||
) -> Results:
|
|
||||||
"""
|
|
||||||
Command to run on the remote host via ssh
|
|
||||||
|
|
||||||
@return a lists of tuples containing Host and the result of the command for this Host
|
|
||||||
"""
|
|
||||||
return self._run(
|
|
||||||
cmd,
|
|
||||||
opts,
|
|
||||||
extra_env=extra_env,
|
|
||||||
verbose_ssh=verbose_ssh,
|
|
||||||
tty=tty,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_local(
|
|
||||||
self,
|
|
||||||
cmd: list[str],
|
|
||||||
opts: RunOpts | None = None,
|
|
||||||
*,
|
|
||||||
extra_env: dict[str, str] | None = None,
|
|
||||||
) -> Results:
|
|
||||||
"""
|
|
||||||
Command to run locally for each host in the group in parallel
|
|
||||||
|
|
||||||
@return a lists of tuples containing Host and the result of the command for this Host
|
|
||||||
"""
|
|
||||||
|
|
||||||
return self._run(
|
|
||||||
cmd,
|
|
||||||
opts,
|
|
||||||
local=True,
|
|
||||||
extra_env=extra_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
def run_function(
|
|
||||||
self, func: Callable[[Host], T], check: bool = True
|
|
||||||
) -> list[HostResult[T]]:
|
|
||||||
"""
|
|
||||||
Function to run for each host in the group in parallel
|
|
||||||
"""
|
|
||||||
threads = []
|
|
||||||
results: list[HostResult[T]] = [
|
|
||||||
HostResult(h, ClanError(f"No result set for thread {i}"))
|
|
||||||
for (i, h) in enumerate(self.hosts)
|
|
||||||
]
|
|
||||||
for i, host in enumerate(self.hosts):
|
|
||||||
thread = Thread(
|
|
||||||
target=_worker,
|
|
||||||
args=(func, host, results, i),
|
|
||||||
)
|
|
||||||
threads.append(thread)
|
|
||||||
|
|
||||||
for thread in threads:
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
for thread in threads:
|
|
||||||
thread.join()
|
|
||||||
if check:
|
|
||||||
self._reraise_errors(results)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def filter(self, pred: Callable[[Host], bool]) -> "HostGroup":
|
|
||||||
"""Return a new Group with the results filtered by the predicate"""
|
|
||||||
return HostGroup(list(filter(pred, self.hosts)))
|
|
||||||
@@ -488,7 +488,8 @@ def generate_vars(
|
|||||||
raise ClanError(msg) from errors[0][1]
|
raise ClanError(msg) from errors[0][1]
|
||||||
|
|
||||||
if not was_regenerated and len(machines) > 0:
|
if not was_regenerated and len(machines) > 0:
|
||||||
log.info("All vars are already up to date")
|
for machine in machines:
|
||||||
|
machine.info("All vars are already up to date")
|
||||||
|
|
||||||
return was_regenerated
|
return was_regenerated
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ pytest_plugins = [
|
|||||||
"sshd",
|
"sshd",
|
||||||
"command",
|
"command",
|
||||||
"ports",
|
"ports",
|
||||||
"host_group",
|
"hosts",
|
||||||
|
"runtime",
|
||||||
"fixtures_flakes",
|
"fixtures_flakes",
|
||||||
"stdout",
|
"stdout",
|
||||||
"nix_config",
|
"nix_config",
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
import os
|
|
||||||
import pwd
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from clan_cli.ssh.host import Host
|
|
||||||
from clan_cli.ssh.host_group import HostGroup
|
|
||||||
from clan_cli.ssh.host_key import HostKeyCheck
|
|
||||||
from sshd import Sshd
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def host_group(sshd: Sshd) -> HostGroup:
|
|
||||||
login = pwd.getpwuid(os.getuid()).pw_name
|
|
||||||
group = HostGroup(
|
|
||||||
[
|
|
||||||
Host(
|
|
||||||
"127.0.0.1",
|
|
||||||
port=sshd.port,
|
|
||||||
user=login,
|
|
||||||
key=sshd.key,
|
|
||||||
host_key_check=HostKeyCheck.NONE,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return group
|
|
||||||
23
pkgs/clan-cli/tests/hosts.py
Normal file
23
pkgs/clan-cli/tests/hosts.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import os
|
||||||
|
import pwd
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from clan_cli.ssh.host import Host
|
||||||
|
from clan_cli.ssh.host_key import HostKeyCheck
|
||||||
|
from sshd import Sshd
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def hosts(sshd: Sshd) -> list[Host]:
|
||||||
|
login = pwd.getpwuid(os.getuid()).pw_name
|
||||||
|
group = [
|
||||||
|
Host(
|
||||||
|
"127.0.0.1",
|
||||||
|
port=sshd.port,
|
||||||
|
user=login,
|
||||||
|
key=sshd.key,
|
||||||
|
host_key_check=HostKeyCheck.NONE,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return group
|
||||||
7
pkgs/clan-cli/tests/runtime.py
Normal file
7
pkgs/clan-cli/tests/runtime.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import pytest
|
||||||
|
from clan_cli.async_run import AsyncRuntime
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def runtime() -> AsyncRuntime:
|
||||||
|
return AsyncRuntime()
|
||||||
@@ -120,6 +120,7 @@ def test_all_dataclasses() -> None:
|
|||||||
excludes = [
|
excludes = [
|
||||||
"api/__init__.py",
|
"api/__init__.py",
|
||||||
"cmd.py", # We don't want the UI to have access to the cmd module anyway
|
"cmd.py", # We don't want the UI to have access to the cmd module anyway
|
||||||
|
"async_run.py", # We don't want the UI to have access to the async_run module anyway
|
||||||
]
|
]
|
||||||
|
|
||||||
cli_path = Path("clan_cli").resolve()
|
cli_path = Path("clan_cli").resolve()
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from clan_cli.facts.secret_modules.password_store import SecretStore
|
|||||||
from clan_cli.machines.facts import machine_get_fact
|
from clan_cli.machines.facts import machine_get_fact
|
||||||
from clan_cli.machines.machines import Machine
|
from clan_cli.machines.machines import Machine
|
||||||
from clan_cli.nix import nix_shell
|
from clan_cli.nix import nix_shell
|
||||||
from clan_cli.ssh.host_group import HostGroup
|
from clan_cli.ssh.host import Host
|
||||||
from fixtures_flakes import ClanFlake
|
from fixtures_flakes import ClanFlake
|
||||||
from helpers import cli
|
from helpers import cli
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ def test_upload_secret(
|
|||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
flake: ClanFlake,
|
flake: ClanFlake,
|
||||||
temporary_home: Path,
|
temporary_home: Path,
|
||||||
host_group: HostGroup,
|
hosts: list[Host],
|
||||||
) -> None:
|
) -> None:
|
||||||
flake.clan_modules = [
|
flake.clan_modules = [
|
||||||
"root-password",
|
"root-password",
|
||||||
@@ -27,7 +27,7 @@ def test_upload_secret(
|
|||||||
config = flake.machines["vm1"]
|
config = flake.machines["vm1"]
|
||||||
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
|
config["nixpkgs"]["hostPlatform"] = "x86_64-linux"
|
||||||
config["clan"]["core"]["networking"]["zerotier"]["controller"]["enable"] = True
|
config["clan"]["core"]["networking"]["zerotier"]["controller"]["enable"] = True
|
||||||
host = host_group.hosts[0]
|
host = hosts[0]
|
||||||
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
|
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
|
||||||
config["clan"]["core"]["networking"]["targetHost"] = addr
|
config["clan"]["core"]["networking"]["targetHost"] = addr
|
||||||
config["clan"]["user-password"]["user"] = "alice"
|
config["clan"]["user-password"]["user"] = "alice"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from clan_cli.ssh.host_group import HostGroup
|
from clan_cli.ssh.host import Host
|
||||||
from fixtures_flakes import FlakeForTest
|
from fixtures_flakes import FlakeForTest
|
||||||
from helpers import cli
|
from helpers import cli
|
||||||
|
|
||||||
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
|||||||
def test_secrets_upload(
|
def test_secrets_upload(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
test_flake_with_core: FlakeForTest,
|
test_flake_with_core: FlakeForTest,
|
||||||
host_group: HostGroup,
|
hosts: list[Host],
|
||||||
age_keys: list["KeyPair"],
|
age_keys: list["KeyPair"],
|
||||||
) -> None:
|
) -> None:
|
||||||
monkeypatch.chdir(test_flake_with_core.path)
|
monkeypatch.chdir(test_flake_with_core.path)
|
||||||
@@ -48,7 +48,7 @@ def test_secrets_upload(
|
|||||||
)
|
)
|
||||||
|
|
||||||
flake = test_flake_with_core.path.joinpath("flake.nix")
|
flake = test_flake_with_core.path.joinpath("flake.nix")
|
||||||
host = host_group.hosts[0]
|
host = hosts[0]
|
||||||
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
|
addr = f"{host.user}@{host.host}:{host.port}?StrictHostKeyChecking=no&UserKnownHostsFile=/dev/null&IdentityFile={host.key}"
|
||||||
new_text = flake.read_text().replace("__CLAN_TARGET_ADDRESS__", addr)
|
new_text = flake.read_text().replace("__CLAN_TARGET_ADDRESS__", addr)
|
||||||
|
|
||||||
|
|||||||
@@ -1,70 +1,49 @@
|
|||||||
from clan_cli.cmd import Log, RunOpts
|
from clan_cli.async_run import AsyncRuntime
|
||||||
|
from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||||
from clan_cli.ssh.host import Host
|
from clan_cli.ssh.host import Host
|
||||||
from clan_cli.ssh.host_group import HostGroup
|
|
||||||
|
|
||||||
hosts = HostGroup([Host("some_host")])
|
host = Host("some_host")
|
||||||
|
|
||||||
|
|
||||||
def test_run_environment() -> None:
|
def test_run_environment(runtime: AsyncRuntime) -> None:
|
||||||
p2 = hosts.run_local(
|
p2 = runtime.async_run(
|
||||||
|
None,
|
||||||
|
host.run_local,
|
||||||
["echo $env_var"],
|
["echo $env_var"],
|
||||||
RunOpts(shell=True, log=Log.STDERR),
|
RunOpts(shell=True, log=Log.STDERR),
|
||||||
extra_env={"env_var": "true"},
|
extra_env={"env_var": "true"},
|
||||||
)
|
)
|
||||||
assert p2[0].result.stdout == "true\n"
|
|
||||||
|
|
||||||
p3 = hosts.run_local(
|
assert p2.wait().result.stdout == "true\n"
|
||||||
["env"], RunOpts(shell=True, log=Log.STDERR), extra_env={"env_var": "true"}
|
|
||||||
|
p3 = runtime.async_run(
|
||||||
|
None,
|
||||||
|
host.run_local,
|
||||||
|
["env"],
|
||||||
|
RunOpts(shell=True, log=Log.STDERR),
|
||||||
|
extra_env={"env_var": "true"},
|
||||||
)
|
)
|
||||||
assert "env_var=true" in p3[0].result.stdout
|
assert "env_var=true" in p3.wait().result.stdout
|
||||||
|
|
||||||
|
|
||||||
def test_run_local() -> None:
|
def test_run_local(runtime: AsyncRuntime) -> None:
|
||||||
hosts.run_local(["echo", "hello"])
|
p1 = runtime.async_run(
|
||||||
|
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
|
||||||
|
)
|
||||||
|
assert p1.wait().result.stdout == "hello\n"
|
||||||
|
|
||||||
|
|
||||||
def test_timeout() -> None:
|
def test_timeout(runtime: AsyncRuntime) -> None:
|
||||||
try:
|
p1 = runtime.async_run(None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01))
|
||||||
hosts.run_local(["sleep", "10"], RunOpts(timeout=0.01))
|
error = p1.wait().error
|
||||||
except Exception:
|
assert isinstance(error, ClanCmdTimeoutError)
|
||||||
pass
|
|
||||||
else:
|
|
||||||
msg = "should have raised TimeoutExpired"
|
|
||||||
raise AssertionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_function() -> None:
|
def test_run_exception(runtime: AsyncRuntime) -> None:
|
||||||
def some_func(h: Host) -> bool:
|
p1 = runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True))
|
||||||
par = h.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
|
assert p1.wait().error is not None
|
||||||
return par.stdout == "hello\n"
|
|
||||||
|
|
||||||
res = hosts.run_function(some_func)
|
|
||||||
assert res[0].result
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_exception() -> None:
|
def test_run_local_non_shell(runtime: AsyncRuntime) -> None:
|
||||||
try:
|
p2 = runtime.async_run(None, host.run_local, ["echo", "1"], RunOpts(log=Log.STDERR))
|
||||||
hosts.run_local(["exit 1"], RunOpts(shell=True))
|
assert p2.wait().result.stdout == "1\n"
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
msg = "should have raised Exception"
|
|
||||||
raise AssertionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_function_exception() -> None:
|
|
||||||
def some_func(h: Host) -> None:
|
|
||||||
h.run_local(["exit 1"], RunOpts(shell=True))
|
|
||||||
|
|
||||||
try:
|
|
||||||
hosts.run_function(some_func)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
msg = "should have raised Exception"
|
|
||||||
raise AssertionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_local_non_shell() -> None:
|
|
||||||
p2 = hosts.run_local(["echo", "1"], RunOpts(log=Log.STDERR))
|
|
||||||
assert p2[0].result.stdout == "1\n"
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from clan_cli.cmd import Log, RunOpts
|
from clan_cli.async_run import AsyncRuntime
|
||||||
|
from clan_cli.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||||
from clan_cli.errors import ClanError, CmdOut
|
from clan_cli.errors import ClanError, CmdOut
|
||||||
from clan_cli.ssh.host import Host
|
from clan_cli.ssh.host import Host
|
||||||
from clan_cli.ssh.host_group import HostGroup
|
|
||||||
from clan_cli.ssh.host_key import HostKeyCheck
|
from clan_cli.ssh.host_key import HostKeyCheck
|
||||||
from clan_cli.ssh.parse import parse_deployment_address
|
from clan_cli.ssh.parse import parse_deployment_address
|
||||||
|
|
||||||
@@ -20,52 +20,75 @@ def test_parse_ipv6() -> None:
|
|||||||
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
|
host = parse_deployment_address("foo", "fe80::1%eth0", HostKeyCheck.STRICT)
|
||||||
|
|
||||||
|
|
||||||
def test_run(host_group: HostGroup) -> None:
|
def test_run(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||||
proc = host_group.run_local(["echo", "hello"], RunOpts(log=Log.STDERR))
|
for host in hosts:
|
||||||
assert proc[0].result.stdout == "hello\n"
|
proc = runtime.async_run(
|
||||||
|
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
|
||||||
|
)
|
||||||
|
assert proc.wait().result.stdout == "hello\n"
|
||||||
|
|
||||||
|
|
||||||
def test_run_environment(host_group: HostGroup) -> None:
|
def test_run_environment(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||||
p1 = host_group.run(
|
for host in hosts:
|
||||||
["echo $env_var"],
|
proc = runtime.async_run(
|
||||||
RunOpts(shell=True, log=Log.STDERR),
|
None,
|
||||||
extra_env={"env_var": "true"},
|
host.run_local,
|
||||||
)
|
["echo $env_var"],
|
||||||
assert p1[0].result.stdout == "true\n"
|
RunOpts(shell=True, log=Log.STDERR),
|
||||||
p2 = host_group.run(["env"], RunOpts(log=Log.STDERR), extra_env={"env_var": "true"})
|
extra_env={"env_var": "true"},
|
||||||
assert "env_var=true" in p2[0].result.stdout
|
)
|
||||||
|
assert proc.wait().result.stdout == "true\n"
|
||||||
|
|
||||||
|
for host in hosts:
|
||||||
|
p2 = runtime.async_run(
|
||||||
|
None,
|
||||||
|
host.run_local,
|
||||||
|
["env"],
|
||||||
|
RunOpts(log=Log.STDERR),
|
||||||
|
extra_env={"env_var": "true"},
|
||||||
|
)
|
||||||
|
assert "env_var=true" in p2.wait().result.stdout
|
||||||
|
|
||||||
|
|
||||||
def test_run_no_shell(host_group: HostGroup) -> None:
|
def test_run_no_shell(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||||
proc = host_group.run(["echo", "$hello"], RunOpts(log=Log.STDERR))
|
for host in hosts:
|
||||||
assert proc[0].result.stdout == "$hello\n"
|
proc = runtime.async_run(
|
||||||
|
None, host.run_local, ["echo", "hello"], RunOpts(log=Log.STDERR)
|
||||||
|
)
|
||||||
|
assert proc.wait().result.stdout == "hello\n"
|
||||||
|
|
||||||
|
|
||||||
def test_run_function(host_group: HostGroup) -> None:
|
def test_run_function(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||||
def some_func(h: Host) -> bool:
|
def some_func(h: Host) -> bool:
|
||||||
p = h.run(["echo", "hello"])
|
p = h.run(["echo", "hello"])
|
||||||
return p.stdout == "hello\n"
|
return p.stdout == "hello\n"
|
||||||
|
|
||||||
res = host_group.run_function(some_func)
|
for host in hosts:
|
||||||
assert res[0].result
|
proc = runtime.async_run(None, some_func, host)
|
||||||
|
assert proc.wait().result
|
||||||
|
|
||||||
|
|
||||||
def test_timeout(host_group: HostGroup) -> None:
|
def test_timeout(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||||
try:
|
for host in hosts:
|
||||||
host_group.run_local(["sleep", "10"], RunOpts(timeout=0.01))
|
proc = runtime.async_run(
|
||||||
except Exception:
|
None, host.run_local, ["sleep", "10"], RunOpts(timeout=0.01)
|
||||||
pass
|
)
|
||||||
else:
|
error = proc.wait().error
|
||||||
msg = "should have raised TimeoutExpired"
|
assert isinstance(error, ClanCmdTimeoutError)
|
||||||
raise AssertionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_run_exception(host_group: HostGroup) -> None:
|
def test_run_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||||
r = host_group.run(["exit 1"], RunOpts(check=False, shell=True))
|
for host in hosts:
|
||||||
assert r[0].result.returncode == 1
|
proc = runtime.async_run(
|
||||||
|
None, host.run_local, ["exit 1"], RunOpts(shell=True, check=False)
|
||||||
|
)
|
||||||
|
assert proc.wait().result.returncode == 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
host_group.run(["exit 1"], RunOpts(shell=True))
|
for host in hosts:
|
||||||
|
runtime.async_run(None, host.run_local, ["exit 1"], RunOpts(shell=True))
|
||||||
|
runtime.join_all()
|
||||||
|
runtime.check_all()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@@ -73,12 +96,15 @@ def test_run_exception(host_group: HostGroup) -> None:
|
|||||||
raise AssertionError(msg)
|
raise AssertionError(msg)
|
||||||
|
|
||||||
|
|
||||||
def test_run_function_exception(host_group: HostGroup) -> None:
|
def test_run_function_exception(hosts: list[Host], runtime: AsyncRuntime) -> None:
|
||||||
def some_func(h: Host) -> CmdOut:
|
def some_func(h: Host) -> CmdOut:
|
||||||
return h.run_local(["exit 1"], RunOpts(shell=True))
|
return h.run_local(["exit 1"], RunOpts(shell=True))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
host_group.run_function(some_func)
|
for host in hosts:
|
||||||
|
runtime.async_run(None, some_func, host)
|
||||||
|
runtime.join_all()
|
||||||
|
runtime.check_all()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user