clan-cli: Make 'clan ssh' read out the targetHost to connect to

This commit is contained in:
Qubasa
2025-07-14 19:35:48 +07:00
parent 1c2b72c6f0
commit 9630b6dbe4
4 changed files with 170 additions and 43 deletions

View File

@@ -1,12 +1,14 @@
import argparse
import json
import logging
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from clan_lib.cmd import run
from clan_lib.errors import ClanError
from clan_lib.machines.machines import Machine
from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import HostKeyCheck, Remote
@@ -37,20 +39,23 @@ class DeployInfo:
raise ClanError(msg)
return addrs[0]
@staticmethod
def from_hostnames(
hostname: list[str], host_key_check: HostKeyCheck
def overwrite_remotes(
self,
host_key_check: HostKeyCheck | None = None,
private_key: Path | None = None,
ssh_options: dict[str, str] | None = None,
) -> "DeployInfo":
remotes = []
for host in hostname:
if not host:
msg = "Hostname cannot be empty."
raise ClanError(msg)
remote = Remote.from_ssh_uri(
machine_name="clan-installer", address=host
).override(host_key_check=host_key_check)
remotes.append(remote)
return DeployInfo(addrs=remotes)
"""Return a new DeployInfo with all Remotes overridden with the given host_key_check."""
return DeployInfo(
addrs=[
addr.override(
host_key_check=host_key_check,
private_key=private_key,
ssh_options=ssh_options,
)
for addr in self.addrs
]
)
@staticmethod
def from_json(data: dict[str, Any], host_key_check: HostKeyCheck) -> "DeployInfo":
@@ -103,9 +108,22 @@ def find_reachable_host(deploy_info: DeployInfo) -> Remote | None:
return None
def ssh_shell_from_deploy(deploy_info: DeployInfo) -> None:
def ssh_shell_from_deploy(
deploy_info: DeployInfo, command: list[str] | None = None
) -> None:
if command and len(command) == 1 and command[0].count(" ") > 0:
msg = (
textwrap.dedent("""
It looks like you quoted the remote command.
The first argument should be the command to run, not a quoted string.
""")
.lstrip("\n")
.rstrip("\n")
)
raise ClanError(msg)
if host := find_reachable_host(deploy_info):
host.interactive_ssh()
host.interactive_ssh(command)
return
log.info("Could not reach host via clearnet 'addrs'")
@@ -127,7 +145,7 @@ def ssh_shell_from_deploy(deploy_info: DeployInfo) -> None:
log.info(
"Host reachable via tor address, starting interactive ssh session."
)
tor_addr.interactive_ssh()
tor_addr.interactive_ssh(command)
return
log.error("Could not reach host via tor address.")
@@ -135,19 +153,32 @@ def ssh_shell_from_deploy(deploy_info: DeployInfo) -> None:
def ssh_command_parse(args: argparse.Namespace) -> DeployInfo | None:
host_key_check = args.host_key_check
deploy = None
if args.json:
json_file = Path(args.json)
if json_file.is_file():
data = json.loads(json_file.read_text())
return DeployInfo.from_json(data, host_key_check)
data = json.loads(args.json)
return DeployInfo.from_json(data, host_key_check)
deploy = DeployInfo.from_json(data, host_key_check)
if args.png:
return DeployInfo.from_qr_code(Path(args.png), host_key_check)
deploy = DeployInfo.from_qr_code(Path(args.png), host_key_check)
if hasattr(args, "machines"):
return DeployInfo.from_hostnames(args.machines, host_key_check)
return None
if hasattr(args, "machine") and args.machine:
machine = Machine(args.machine, args.flake)
target = machine.target_host().override(
command_prefix=machine.name, host_key_check=host_key_check
)
deploy = DeployInfo(addrs=[target])
if deploy is None:
return None
ssh_options = {}
for name, value in args.ssh_option or []:
ssh_options[name] = value
deploy = deploy.overwrite_remotes(ssh_options=ssh_options)
return deploy
def ssh_command(args: argparse.Namespace) -> None:
@@ -155,36 +186,63 @@ def ssh_command(args: argparse.Namespace) -> None:
if not deploy_info:
msg = "No MACHINE, --json or --png data provided"
raise ClanError(msg)
ssh_shell_from_deploy(deploy_info)
ssh_shell_from_deploy(deploy_info, args.remote_command)
def register_parser(parser: argparse.ArgumentParser) -> None:
group = parser.add_mutually_exclusive_group(required=True)
machines_parser = group.add_argument(
"machines",
group = parser.add_mutually_exclusive_group()
group.add_argument(
"machine",
type=str,
nargs="*",
default=[],
nargs="?",
metavar="MACHINE",
help="Machine to ssh into.",
help="Machine to ssh into (uses clan.core.networking.targetHost from configuration).",
)
add_dynamic_completer(machines_parser, complete_machines)
group.add_argument(
"-j",
"--json",
help="specify the json file for ssh data (generated by starting the clan installer)",
type=str,
help=(
"Deployment information as a JSON string or path to a JSON file "
"(generated by starting the clan installer)."
),
)
group.add_argument(
"-P",
"--png",
help="specify the json file for ssh data as the qrcode image (generated by starting the clan installer)",
type=str,
help="Deployment information as a QR code image file (generated by starting the clan installer).",
)
parser.add_argument(
"--host-key-check",
choices=["strict", "ask", "tofu", "none"],
default="tofu",
help="Host key (.ssh/known_hosts) check mode.",
)
parser.add_argument(
"--ssh-option",
help="SSH option to set (can be specified multiple times)",
nargs=2,
metavar=("name", "value"),
action="append",
default=[],
)
parser.add_argument(
"-c",
"--remote-command",
type=str,
metavar="COMMAND",
nargs=argparse.REMAINDER,
help="Command to execute on the remote host, needs to be the LAST argument as it takes all remaining arguments.",
)
add_dynamic_completer(
parser._actions[1], # noqa: SLF001
complete_machines,
) # assumes 'machine' is the first positional
parser.set_defaults(func=ssh_command)

View File

@@ -7,6 +7,10 @@ from clan_lib.nix import nix_shell
from clan_lib.ssh.remote import Remote
from clan_cli.ssh.deploy_info import DeployInfo, find_reachable_host
from clan_cli.tests.fixtures_flakes import ClanFlake
from clan_cli.tests.helpers import cli
from clan_cli.tests.nix_config import ConfigItem
from clan_cli.tests.stdout import CaptureOutput
def test_qrcode_scan(temp_dir: Path) -> None:
@@ -69,7 +73,10 @@ def test_from_json() -> None:
@pytest.mark.with_core
def test_find_reachable_host(hosts: list[Remote]) -> None:
host = hosts[0]
deploy_info = DeployInfo.from_hostnames(["172.19.1.2", host.ssh_url()], "none")
uris = ["172.19.1.2", host.ssh_url()]
remotes = [Remote.from_ssh_uri(machine_name="some", address=uri) for uri in uris]
deploy_info = DeployInfo(addrs=remotes)
assert deploy_info.addrs[0].address == "172.19.1.2"
@@ -77,3 +84,42 @@ def test_find_reachable_host(hosts: list[Remote]) -> None:
assert remote is not None
assert remote.ssh_url() == host.ssh_url()
@pytest.mark.with_core
def test_ssh_shell_from_deploy(
hosts: list[Remote],
flake: ClanFlake,
nix_config: dict[str, ConfigItem],
capture_output: CaptureOutput,
) -> None:
host = hosts[0]
machine1_config = flake.machines["m1_machine"]
machine1_config["nixpkgs"]["hostPlatform"] = nix_config["system"].value
machine1_config["clan"]["networking"]["targetHost"] = host.ssh_url()
flake.refresh()
assert host.private_key
success_txt = flake.path / "success.txt"
assert not success_txt.exists()
cli.run(
[
"ssh",
"--flake",
str(flake.path),
"m1_machine",
"--host-key-check=none",
"--ssh-option",
"IdentityFile",
str(host.private_key),
"--remote-command",
"touch",
str(success_txt),
"&&",
"exit 0",
]
)
assert success_txt.exists()