286 lines
9.2 KiB
Python
286 lines
9.2 KiB
Python
import argparse
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Literal, TypeVar, get_args
|
|
|
|
from clan_lib.async_run import AsyncContext, AsyncFuture, AsyncOpts, AsyncRuntime
|
|
from clan_lib.errors import ClanError, text_heading
|
|
from clan_lib.flake import require_flake
|
|
from clan_lib.machines.generations import MachineGeneration, get_machine_generations
|
|
from clan_lib.machines.machines import Machine
|
|
from clan_lib.metrics.telegraf import MonitoringNotEnabledError
|
|
from clan_lib.metrics.version import check_machine_up_to_date
|
|
from clan_lib.network.network import get_best_remote
|
|
from clan_lib.ssh.host_key import HostKeyCheck
|
|
from clan_lib.ssh.localhost import LocalHost
|
|
from clan_lib.ssh.remote import Remote
|
|
|
|
from clan_cli.completions import (
|
|
add_dynamic_completer,
|
|
complete_machines,
|
|
complete_tags,
|
|
)
|
|
from clan_cli.machines.update import get_machines_for_update
|
|
|
|
if TYPE_CHECKING:
|
|
from clan_lib.ssh.host import Host
|
|
|
|
log = logging.getLogger(__name__)

# Status label displayed next to a machine's current generation.
UpToDateType = Literal["up-to-date", "out-of-date", "unknown"]
|
|
|
|
|
|
def print_generations(
    generations: list[MachineGeneration],
    needs_update: UpToDateType = "unknown",
) -> None:
    """Print the generations of one machine as an aligned text table.

    The row of the generation flagged as current is tagged with an arrow and
    the ``needs_update`` label, e.g. ``42 ← (up-to-date)``. The table is
    followed by one empty line.

    Args:
        generations: Generations to list, one table row each.
        needs_update: Status label shown next to the current generation.
    """
    headers = [
        "Generation (Up-To-Date)",
        "Date",
        "NixOS Version",
        "Kernel Version",
    ]

    # Stringify every cell up front so that measuring the width and padding
    # the cell operate on the same text, even for non-str field values.
    rows: list[list[str]] = []
    for gen in generations:
        gen_marker = f" ← ({needs_update})" if gen.current else ""
        rows.append(
            [
                f"{gen.generation}{gen_marker}",
                str(gen.date),
                str(gen.nixos_version),
                str(gen.kernel_version),
            ]
        )

    # zip(headers, *rows) yields one tuple per column: the header followed by
    # that column's cell from every row.
    col_widths = [max(len(cell) for cell in column) for column in zip(headers, *rows)]

    # Print header
    print(" | ".join(header.ljust(w) for header, w in zip(headers, col_widths)))
    print("-+-".join("-" * w for w in col_widths))

    # Print rows
    for row in rows:
        print(" | ".join(cell.ljust(w) for cell, w in zip(row, col_widths)))

    print()
|
|
|
|
|
|
def print_summary_table(
    machine_data: dict[Machine, tuple[list[MachineGeneration], UpToDateType]],
) -> None:
    """Print a one-row-per-machine table of each machine's current generation.

    Machines whose generation list contains no entry flagged as current are
    left out. If no machine contributes a row, a short notice is printed
    instead of the table.

    Args:
        machine_data: Maps each machine to its generations and its
            up-to-date status label.
    """
    print(text_heading("Current Generations Summary"))
    headers = ["Machine", "Current Generation", "Date", "NixOS Version", "Up-To-Date"]

    # Cells are stringified once so width measurement and padding agree even
    # when a field is not a str.
    rows: list[list[str]] = []
    for machine, (generations, needs_update) in machine_data.items():
        current_gen = next((gen for gen in generations if gen.current), None)
        if current_gen is None:
            # Nothing is flagged as current for this machine; it gets no row.
            continue
        rows.append(
            [
                str(machine.name),
                str(current_gen.generation),
                str(current_gen.date),
                str(current_gen.nixos_version),
                str(needs_update),
            ]
        )

    if not rows:
        print("Couldn't retrieve data from any machine.")
        return

    # zip(headers, *rows) yields one tuple per column: header plus its cells.
    col_widths = [max(len(cell) for cell in column) for column in zip(headers, *rows)]

    # Print header
    print(" | ".join(header.ljust(w) for header, w in zip(headers, col_widths)))
    print("-+-".join("-" * w for w in col_widths))

    # Print rows
    for row in rows:
        print(" | ".join(cell.ljust(w) for cell, w in zip(row, col_widths)))

    print()
|
|
|
|
|
|
@dataclass(frozen=True)
class MachineVersionData:
    """Pending asynchronous results collected for a single machine."""

    # Future that resolves to the machine's list of generations.
    generations: AsyncFuture[list[MachineGeneration]]
    # Future that resolves to the update-check result; None when the check
    # was skipped.
    machine_update: AsyncFuture[bool] | None
|
|
|
|
|
|
_R = TypeVar("_R")


def _future_outcome(async_future: AsyncFuture[_R]) -> _R | Exception:
    """Return the value of a finished future, or the exception it ended with."""
    aresult = async_future.get_result()
    if aresult is None:
        msg = "Generations result should never be None"
        raise ClanError(msg)
    if aresult.error is not None:
        return aresult.error
    return aresult.result


def generations_command(args: argparse.Namespace) -> None:
    """Show the system generations of the selected machines.

    For every selected machine a generations query is scheduled on an
    ``AsyncRuntime`` and, unless ``--skip-outdated-check`` is set, an
    up-to-date check as well; the command then waits for all of them.
    When machines or tags were named explicitly a detailed table is printed
    per machine, otherwise a single summary table.

    Args:
        args: Parsed command line; reads ``flake``, ``machines``, ``tags``,
            ``target_host``, ``host_key_check`` and ``skip_outdated_check``.

    Raises:
        ClanError: If ``--target-host`` is combined with more than one
            machine, or if querying at least one machine failed. In the
            latter case the results of the successful machines are printed
            first and the error lists every failed machine.
    """
    flake = require_flake(args.flake)

    machines_to_update = get_machines_for_update(flake, args.machines, args.tags)

    if args.target_host is not None and len(machines_to_update) > 1:
        msg = "Target Host can only be set for one machine"
        raise ClanError(msg)

    host_key_check = args.host_key_check
    machine_generations: dict[Machine, MachineVersionData] = {}
    with AsyncRuntime() as runtime:
        for machine in machines_to_update:
            if args.target_host:
                target_host: Host | None = None
                if args.target_host == "localhost":
                    target_host = LocalHost()
                else:
                    target_host = Remote.from_ssh_uri(
                        machine_name=machine.name,
                        address=args.target_host,
                    ).override(host_key_check=host_key_check)
            else:
                try:
                    with get_best_remote(machine) as _remote:
                        target_host = machine.target_host().override(
                            host_key_check=host_key_check
                        )
                except ClanError:
                    log.warning(
                        f"Skipping {machine.name} as it has no target host configured."
                    )
                    continue
            generations = runtime.async_run(
                AsyncOpts(
                    tid=machine.name,
                    async_ctx=AsyncContext(prefix=machine.name),
                ),
                get_machine_generations,
                target_host=target_host,
            )
            if args.skip_outdated_check:
                machine_update = None
            else:
                machine_update = runtime.async_run(
                    AsyncOpts(
                        tid=machine.name + "-needs-update",
                        async_ctx=AsyncContext(prefix=machine.name),
                    ),
                    check_machine_up_to_date,
                    machine=machine,
                    target_host=target_host,
                )
            machine_generations[machine] = MachineVersionData(
                generations, machine_update
            )
        runtime.join_all()

    errors: dict[Machine, Exception] = {}
    successful_machines: dict[
        Machine, tuple[list[MachineGeneration], UpToDateType]
    ] = {}

    for machine, async_version_data in machine_generations.items():
        mgenerations = _future_outcome(async_version_data.generations)
        if isinstance(mgenerations, Exception):
            errors[machine] = mgenerations
            continue

        needs_update: UpToDateType
        if async_version_data.machine_update is None:
            needs_update = "unknown"
        else:
            eneeds_update = _future_outcome(async_version_data.machine_update)
            # The monitoring case must be tested before the generic
            # Exception case so it is downgraded to a warning.
            if isinstance(eneeds_update, MonitoringNotEnabledError):
                log.warning(
                    f"Skipping up-to-date check for {machine.name} as monitoring is not enabled."
                )
                needs_update = "unknown"
            elif isinstance(eneeds_update, Exception):
                errors[machine] = eneeds_update
                continue
            else:
                # NOTE(review): a truthy result is treated as "needs update"
                # although the helper is named check_machine_up_to_date;
                # confirm against that function's return contract.
                needs_update = "out-of-date" if eneeds_update else "up-to-date"

        successful_machines[machine] = (mgenerations, needs_update)

    # Check if specific machines were requested
    specific_machines_requested = bool(args.machines or args.tags)

    if specific_machines_requested:
        # Print detailed generations for each machine
        for mgenerations, needs_update in successful_machines.values():
            print_generations(
                generations=mgenerations,
                needs_update=needs_update,
            )
    else:
        # Print summary table
        print_summary_table(successful_machines)

    if errors:
        # Report every failed machine, not just the first one. With a single
        # failure the message is unchanged from the previous behaviour.
        msg = "\n".join(
            f"Failed for machine {machine.name}: {error}"
            for machine, error in errors.items()
        )
        first_error = next(iter(errors.values()))
        raise ClanError(msg) from first_error
|
|
|
|
|
|
def register_generations_parser(parser: argparse.ArgumentParser) -> None:
    """Register the arguments of the generations command on ``parser``.

    Adds the positional ``machines`` list, ``--tags``, ``--host-key-check``,
    ``--target-host`` and ``--skip-outdated-check``, attaches shell
    completers for machines and tags, and sets ``generations_command`` as
    the parser's ``func`` default.

    Args:
        parser: The argument parser to extend.
    """
    machines_parser = parser.add_argument(
        "machines",
        type=str,
        nargs="*",
        default=[],
        metavar="MACHINE",
        # This command only lists generations; the previous text was copied
        # from the update command and described updating machines.
        help="Machine to show generations for. If no machines are specified, all machines that don't require explicit updates will be queried.",
    )
    add_dynamic_completer(machines_parser, complete_machines)

    tag_parser = parser.add_argument(
        "--tags",
        nargs="+",
        default=[],
        help="Tags that machines should be queried for. Multiple tags will intersect.",
    )
    add_dynamic_completer(tag_parser, complete_tags)

    parser.add_argument(
        "--host-key-check",
        choices=list(get_args(HostKeyCheck)),
        default="ask",
        help="Host key (.ssh/known_hosts) check mode.",
    )
    parser.add_argument(
        "--target-host",
        type=str,
        help="Address of the machine to query, in the format of user@host:1234.",
    )

    parser.add_argument(
        "--skip-outdated-check",
        action="store_true",
        help="Skip checking if the current generation is outdated (faster).",
    )
    parser.set_defaults(func=generations_command)
|