clan-cli network: refactor, use new networking in ssh and install commands
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from typing import get_args
|
||||
|
||||
@@ -8,6 +10,7 @@ from clan_lib.errors import ClanError
|
||||
from clan_lib.flake import require_flake
|
||||
from clan_lib.machines.install import BuildOn, InstallOptions, run_machine_install
|
||||
from clan_lib.machines.machines import Machine
|
||||
from clan_lib.network.qr_code import read_qr_image, read_qr_json
|
||||
from clan_lib.ssh.host_key import HostKeyCheck
|
||||
from clan_lib.ssh.remote import Remote
|
||||
|
||||
@@ -17,11 +20,6 @@ from clan_cli.completions import (
|
||||
complete_target_host,
|
||||
)
|
||||
from clan_cli.machines.hardware import HardwareConfig
|
||||
from clan_cli.ssh.deploy_info import (
|
||||
find_reachable_host,
|
||||
get_tor_remote,
|
||||
ssh_command_parse,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,81 +29,71 @@ def install_command(args: argparse.Namespace) -> None:
|
||||
flake = require_flake(args.flake)
|
||||
# Only if the caller did not specify a target_host via args.target_host
|
||||
# Find a suitable target_host that is reachable
|
||||
target_host_str = args.target_host
|
||||
remotes: list[Remote] | None = (
|
||||
ssh_command_parse(args) if target_host_str is None else None
|
||||
)
|
||||
|
||||
use_tor = False
|
||||
if remotes:
|
||||
host = find_reachable_host(remotes)
|
||||
if host is None or host.socks_port:
|
||||
use_tor = True
|
||||
tor_remote = get_tor_remote(remotes)
|
||||
target_host_str = tor_remote.target
|
||||
else:
|
||||
target_host_str = host.target
|
||||
|
||||
if args.password:
|
||||
password = args.password
|
||||
elif remotes and remotes[0].password:
|
||||
password = remotes[0].password
|
||||
else:
|
||||
password = None
|
||||
|
||||
machine = Machine(name=args.machine, flake=flake)
|
||||
host_key_check = args.host_key_check
|
||||
|
||||
if target_host_str is not None:
|
||||
target_host = Remote.from_ssh_uri(
|
||||
machine_name=machine.name, address=target_host_str
|
||||
).override(host_key_check=host_key_check)
|
||||
else:
|
||||
target_host = machine.target_host().override(host_key_check=host_key_check)
|
||||
|
||||
if args.identity_file:
|
||||
target_host = target_host.override(private_key=args.identity_file)
|
||||
|
||||
if machine._class_ == "darwin":
|
||||
msg = "Installing macOS machines is not yet supported"
|
||||
raise ClanError(msg)
|
||||
|
||||
if not args.yes:
|
||||
while True:
|
||||
ask = (
|
||||
input(f"Install {args.machine} to {target_host.target}? [y/N] ")
|
||||
.strip()
|
||||
.lower()
|
||||
with ExitStack() as stack:
|
||||
remote: Remote
|
||||
if args.target_host:
|
||||
# TODO add network support here with either --network or some url magic
|
||||
remote = Remote.from_ssh_uri(
|
||||
machine_name=args.machine, address=args.target_host
|
||||
)
|
||||
if ask == "y":
|
||||
break
|
||||
if ask == "n" or ask == "":
|
||||
return None
|
||||
print(f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no.")
|
||||
elif args.png:
|
||||
data = read_qr_image(Path(args.png))
|
||||
qr_code = read_qr_json(data, args.flake)
|
||||
remote = stack.enter_context(qr_code.get_best_remote())
|
||||
elif args.json:
|
||||
json_file = Path(args.json)
|
||||
if json_file.is_file():
|
||||
data = json.loads(json_file.read_text())
|
||||
else:
|
||||
data = json.loads(args.json)
|
||||
|
||||
if args.identity_file:
|
||||
target_host = target_host.override(private_key=args.identity_file)
|
||||
qr_code = read_qr_json(data, args.flake)
|
||||
remote = stack.enter_context(qr_code.get_best_remote())
|
||||
else:
|
||||
msg = "No MACHINE, --json or --png data provided"
|
||||
raise ClanError(msg)
|
||||
|
||||
if password:
|
||||
target_host = target_host.override(password=password)
|
||||
machine = Machine(name=args.machine, flake=flake)
|
||||
if args.host_key_check:
|
||||
remote.override(host_key_check=args.host_key_check)
|
||||
|
||||
if use_tor:
|
||||
target_host = target_host.override(
|
||||
socks_port=9050, socks_wrapper=["torify"]
|
||||
if machine._class_ == "darwin":
|
||||
msg = "Installing macOS machines is not yet supported"
|
||||
raise ClanError(msg)
|
||||
|
||||
if not args.yes:
|
||||
while True:
|
||||
ask = (
|
||||
input(f"Install {args.machine} to {remote.target}? [y/N] ")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
if ask == "y":
|
||||
break
|
||||
if ask == "n" or ask == "":
|
||||
return None
|
||||
print(
|
||||
f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no."
|
||||
)
|
||||
|
||||
if args.identity_file:
|
||||
remote = remote.override(private_key=args.identity_file)
|
||||
|
||||
if args.password:
|
||||
remote = remote.override(password=args.password)
|
||||
|
||||
return run_machine_install(
|
||||
InstallOptions(
|
||||
machine=machine,
|
||||
kexec=args.kexec,
|
||||
phases=args.phases,
|
||||
debug=args.debug,
|
||||
no_reboot=args.no_reboot,
|
||||
build_on=args.build_on if args.build_on is not None else None,
|
||||
update_hardware_config=HardwareConfig(args.update_hardware_config),
|
||||
),
|
||||
target_host=remote,
|
||||
)
|
||||
|
||||
return run_machine_install(
|
||||
InstallOptions(
|
||||
machine=machine,
|
||||
kexec=args.kexec,
|
||||
phases=args.phases,
|
||||
debug=args.debug,
|
||||
no_reboot=args.no_reboot,
|
||||
build_on=args.build_on if args.build_on is not None else None,
|
||||
update_hardware_config=HardwareConfig(args.update_hardware_config),
|
||||
),
|
||||
target_host=target_host,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
log.warning("Interrupted by user")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from typing import get_args
|
||||
|
||||
from clan_lib.errors import ClanError
|
||||
from clan_lib.machines.machines import Machine
|
||||
from clan_lib.network.qr_code import parse_qr_image_to_json, parse_qr_json_to_networks
|
||||
from clan_lib.network.tor.lib import spawn_tor
|
||||
from clan_lib.network.network import get_best_remote
|
||||
from clan_lib.network.qr_code import read_qr_image, read_qr_json
|
||||
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
||||
|
||||
from clan_cli.completions import (
|
||||
@@ -35,112 +34,42 @@ def get_tor_remote(remotes: list[Remote]) -> Remote:
|
||||
return tor_remotes[0]
|
||||
|
||||
|
||||
def find_reachable_host(remotes: list[Remote]) -> Remote | None:
|
||||
# If we only have one address, we have no choice but to use it.
|
||||
if len(remotes) == 1:
|
||||
return remotes[0]
|
||||
|
||||
for remote in remotes:
|
||||
with contextlib.suppress(ClanError):
|
||||
remote.check_machine_ssh_reachable()
|
||||
return remote
|
||||
return None
|
||||
|
||||
|
||||
def ssh_shell_from_remotes(
|
||||
remotes: list[Remote], command: list[str] | None = None
|
||||
) -> None:
|
||||
if command and len(command) == 1 and command[0].count(" ") > 0:
|
||||
msg = (
|
||||
textwrap.dedent("""
|
||||
It looks like you quoted the remote command.
|
||||
The first argument should be the command to run, not a quoted string.
|
||||
""")
|
||||
.lstrip("\n")
|
||||
.rstrip("\n")
|
||||
)
|
||||
raise ClanError(msg)
|
||||
|
||||
if host := find_reachable_host(remotes):
|
||||
host.interactive_ssh(command)
|
||||
return
|
||||
|
||||
log.info("Could not reach host via clearnet addresses")
|
||||
log.info("Trying to reach host via tor")
|
||||
|
||||
tor_remotes = [r for r in remotes if r.socks_port]
|
||||
if not tor_remotes:
|
||||
msg = "No tor address provided, please provide a tor address."
|
||||
raise ClanError(msg)
|
||||
|
||||
with spawn_tor():
|
||||
for tor_remote in tor_remotes:
|
||||
log.info(f"Trying to reach host via tor address: {tor_remote}")
|
||||
|
||||
with contextlib.suppress(ClanError):
|
||||
tor_remote.check_machine_ssh_reachable()
|
||||
|
||||
log.info(
|
||||
"Host reachable via tor address, starting interactive ssh session."
|
||||
)
|
||||
tor_remote.interactive_ssh(command)
|
||||
return
|
||||
|
||||
log.error("Could not reach host via tor address.")
|
||||
|
||||
|
||||
def ssh_command_parse(args: argparse.Namespace) -> list[Remote] | None:
|
||||
host_key_check = args.host_key_check
|
||||
remotes = None
|
||||
|
||||
if args.json:
|
||||
json_file = Path(args.json)
|
||||
if json_file.is_file():
|
||||
data = json.loads(json_file.read_text())
|
||||
else:
|
||||
data = json.loads(args.json)
|
||||
|
||||
networks = parse_qr_json_to_networks(data, args.flake)
|
||||
remotes = []
|
||||
for _network_type, network_data in networks.items():
|
||||
remote = network_data["remote"]
|
||||
remotes.append(remote.override(host_key_check=host_key_check))
|
||||
|
||||
elif args.png:
|
||||
data = parse_qr_image_to_json(Path(args.png))
|
||||
networks = parse_qr_json_to_networks(data, args.flake)
|
||||
remotes = []
|
||||
for _network_type, network_data in networks.items():
|
||||
remote = network_data["remote"]
|
||||
remotes.append(remote.override(host_key_check=host_key_check))
|
||||
|
||||
elif hasattr(args, "machine") and args.machine:
|
||||
machine = Machine(args.machine, args.flake)
|
||||
target = machine.target_host().override(
|
||||
command_prefix=machine.name, host_key_check=host_key_check
|
||||
)
|
||||
remotes = [target]
|
||||
else:
|
||||
return None
|
||||
|
||||
ssh_options = None
|
||||
if hasattr(args, "ssh_option") and args.ssh_option:
|
||||
ssh_options = {}
|
||||
for name, value in args.ssh_option:
|
||||
ssh_options[name] = value
|
||||
|
||||
if ssh_options:
|
||||
remotes = [remote.override(ssh_options=ssh_options) for remote in remotes]
|
||||
|
||||
return remotes
|
||||
|
||||
|
||||
def ssh_command(args: argparse.Namespace) -> None:
|
||||
remotes = ssh_command_parse(args)
|
||||
if not remotes:
|
||||
msg = "No MACHINE, --json or --png data provided"
|
||||
raise ClanError(msg)
|
||||
ssh_shell_from_remotes(remotes, args.remote_command)
|
||||
with ExitStack() as stack:
|
||||
remote: Remote
|
||||
if hasattr(args, "machine") and args.machine:
|
||||
machine = Machine(args.machine, args.flake)
|
||||
remote = stack.enter_context(get_best_remote(machine))
|
||||
elif args.png:
|
||||
data = read_qr_image(Path(args.png))
|
||||
qr_code = read_qr_json(data, args.flake)
|
||||
remote = stack.enter_context(qr_code.get_best_remote())
|
||||
elif args.json:
|
||||
json_file = Path(args.json)
|
||||
if json_file.is_file():
|
||||
data = json.loads(json_file.read_text())
|
||||
else:
|
||||
data = json.loads(args.json)
|
||||
|
||||
qr_code = read_qr_json(data, args.flake)
|
||||
remote = stack.enter_context(qr_code.get_best_remote())
|
||||
else:
|
||||
msg = "No MACHINE, --json or --png data provided"
|
||||
raise ClanError(msg)
|
||||
|
||||
# Convert ssh_option list to dictionary
|
||||
ssh_options = {}
|
||||
if args.ssh_option:
|
||||
for name, value in args.ssh_option:
|
||||
ssh_options[name] = value
|
||||
|
||||
remote = remote.override(
|
||||
host_key_check=args.host_key_check, ssh_options=ssh_options
|
||||
)
|
||||
if args.remote_command:
|
||||
remote.interactive_ssh(args.remote_command)
|
||||
else:
|
||||
remote.interactive_ssh()
|
||||
|
||||
|
||||
def register_parser(parser: argparse.ArgumentParser) -> None:
|
||||
|
||||
@@ -4,11 +4,10 @@ from pathlib import Path
|
||||
import pytest
|
||||
from clan_lib.cmd import RunOpts, run
|
||||
from clan_lib.flake import Flake
|
||||
from clan_lib.network.qr_code import parse_qr_image_to_json, parse_qr_json_to_networks
|
||||
from clan_lib.network.qr_code import read_qr_image, read_qr_json
|
||||
from clan_lib.nix import nix_shell
|
||||
from clan_lib.ssh.remote import Remote
|
||||
|
||||
from clan_cli.ssh.deploy_info import find_reachable_host
|
||||
from clan_cli.tests.fixtures_flakes import ClanFlake
|
||||
from clan_cli.tests.helpers import cli
|
||||
|
||||
@@ -28,84 +27,93 @@ def test_qrcode_scan(temp_dir: Path, flake: ClanFlake) -> None:
|
||||
run(cmd, RunOpts(input=data.encode()))
|
||||
|
||||
# Call the qrcode_scan function
|
||||
json_data = parse_qr_image_to_json(img_path)
|
||||
networks = parse_qr_json_to_networks(json_data, Flake(str(flake.path)))
|
||||
json_data = read_qr_image(img_path)
|
||||
qr_code = read_qr_json(json_data, Flake(str(flake.path)))
|
||||
|
||||
# Get direct network data
|
||||
direct_data = networks.get("direct")
|
||||
assert direct_data is not None
|
||||
assert "network" in direct_data
|
||||
assert "remote" in direct_data
|
||||
# Check addresses
|
||||
addresses = qr_code.addresses
|
||||
assert len(addresses) >= 2 # At least direct and tor
|
||||
|
||||
# Get the remote
|
||||
host = direct_data["remote"]
|
||||
assert host.address == "192.168.122.86"
|
||||
assert host.user == "root"
|
||||
assert host.password == "scabbed-defender-headlock"
|
||||
# Find direct connection
|
||||
direct_remote = None
|
||||
for addr in addresses:
|
||||
if addr.network.module_name == "clan_lib.network.direct":
|
||||
direct_remote = addr.remote
|
||||
break
|
||||
|
||||
# Get tor network data
|
||||
tor_data = networks.get("tor")
|
||||
assert tor_data is not None
|
||||
assert "network" in tor_data
|
||||
assert "remote" in tor_data
|
||||
assert direct_remote is not None
|
||||
assert direct_remote.address == "192.168.122.86"
|
||||
assert direct_remote.user == "root"
|
||||
assert direct_remote.password == "scabbed-defender-headlock"
|
||||
|
||||
# Get the remote
|
||||
tor_host = tor_data["remote"]
|
||||
# Find tor connection
|
||||
tor_remote = None
|
||||
for addr in addresses:
|
||||
if addr.network.module_name == "clan_lib.network.tor":
|
||||
tor_remote = addr.remote
|
||||
break
|
||||
|
||||
assert tor_remote is not None
|
||||
assert (
|
||||
tor_host.address
|
||||
tor_remote.address
|
||||
== "qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion"
|
||||
)
|
||||
assert tor_host.socks_port == 9050
|
||||
assert tor_host.password == "scabbed-defender-headlock"
|
||||
assert tor_host.user == "root"
|
||||
assert tor_remote.socks_port == 9050
|
||||
assert tor_remote.password == "scabbed-defender-headlock"
|
||||
assert tor_remote.user == "root"
|
||||
|
||||
|
||||
def test_from_json(temp_dir: Path) -> None:
|
||||
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
|
||||
flake = Flake(str(temp_dir))
|
||||
networks = parse_qr_json_to_networks(json.loads(data), flake)
|
||||
qr_code = read_qr_json(json.loads(data), flake)
|
||||
|
||||
# Get direct network data
|
||||
direct_data = networks.get("direct")
|
||||
assert direct_data is not None
|
||||
assert "network" in direct_data
|
||||
assert "remote" in direct_data
|
||||
# Check addresses
|
||||
addresses = qr_code.addresses
|
||||
assert len(addresses) >= 2 # At least direct and tor
|
||||
|
||||
# Get the remote
|
||||
host = direct_data["remote"]
|
||||
assert host.password == "scabbed-defender-headlock"
|
||||
assert host.address == "192.168.122.86"
|
||||
# Find direct connection
|
||||
direct_remote = None
|
||||
for addr in addresses:
|
||||
if addr.network.module_name == "clan_lib.network.direct":
|
||||
direct_remote = addr.remote
|
||||
break
|
||||
|
||||
# Get tor network data
|
||||
tor_data = networks.get("tor")
|
||||
assert tor_data is not None
|
||||
assert "network" in tor_data
|
||||
assert "remote" in tor_data
|
||||
assert direct_remote is not None
|
||||
assert direct_remote.password == "scabbed-defender-headlock"
|
||||
assert direct_remote.address == "192.168.122.86"
|
||||
|
||||
# Get the remote
|
||||
tor_host = tor_data["remote"]
|
||||
# Find tor connection
|
||||
tor_remote = None
|
||||
for addr in addresses:
|
||||
if addr.network.module_name == "clan_lib.network.tor":
|
||||
tor_remote = addr.remote
|
||||
break
|
||||
|
||||
assert tor_remote is not None
|
||||
assert (
|
||||
tor_host.address
|
||||
tor_remote.address
|
||||
== "qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion"
|
||||
)
|
||||
assert tor_host.socks_port == 9050
|
||||
assert tor_host.password == "scabbed-defender-headlock"
|
||||
assert tor_host.user == "root"
|
||||
assert tor_remote.socks_port == 9050
|
||||
assert tor_remote.password == "scabbed-defender-headlock"
|
||||
assert tor_remote.user == "root"
|
||||
|
||||
|
||||
@pytest.mark.with_core
|
||||
def test_find_reachable_host(hosts: list[Remote]) -> None:
|
||||
host = hosts[0]
|
||||
|
||||
uris = ["172.19.1.2", host.ssh_url()]
|
||||
remotes = [Remote.from_ssh_uri(machine_name="some", address=uri) for uri in uris]
|
||||
|
||||
assert remotes[0].address == "172.19.1.2"
|
||||
|
||||
remote = find_reachable_host(remotes=remotes)
|
||||
|
||||
assert remote is not None
|
||||
assert remote.ssh_url() == host.ssh_url()
|
||||
# TODO: This test needs to be updated to use get_best_remote from clan_lib.network.network
|
||||
# @pytest.mark.with_core
|
||||
# def test_find_reachable_host(hosts: list[Remote]) -> None:
|
||||
# host = hosts[0]
|
||||
#
|
||||
# uris = ["172.19.1.2", host.ssh_url()]
|
||||
# remotes = [Remote.from_ssh_uri(machine_name="some", address=uri) for uri in uris]
|
||||
#
|
||||
# assert remotes[0].address == "172.19.1.2"
|
||||
#
|
||||
# remote = find_reachable_host(remotes=remotes)
|
||||
#
|
||||
# assert remote is not None
|
||||
# assert remote.ssh_url() == host.ssh_url()
|
||||
|
||||
|
||||
@pytest.mark.with_core
|
||||
|
||||
@@ -10,7 +10,6 @@ from clan_cli.facts import secret_modules as facts_secret_modules
|
||||
from clan_cli.vars._types import StoreBase
|
||||
|
||||
from clan_lib.api import API
|
||||
from clan_lib.errors import ClanError
|
||||
from clan_lib.flake import ClanSelectError, Flake
|
||||
from clan_lib.nix_models.clan import InventoryMachine
|
||||
from clan_lib.ssh.remote import Remote
|
||||
@@ -125,15 +124,10 @@ class Machine:
|
||||
return self.flake.path
|
||||
|
||||
def target_host(self) -> Remote:
|
||||
remote = get_machine_host(self.name, self.flake, field="targetHost")
|
||||
if remote is None:
|
||||
msg = f"'targetHost' is not set for machine '{self.name}'"
|
||||
raise ClanError(
|
||||
msg,
|
||||
description="See https://docs.clan.lol/guides/getting-started/update/#setting-the-target-host for more information.",
|
||||
)
|
||||
data = remote.data
|
||||
return data
|
||||
from clan_lib.network.network import get_best_remote
|
||||
|
||||
with get_best_remote(self) as remote:
|
||||
return remote
|
||||
|
||||
def build_host(self) -> Remote | None:
|
||||
"""
|
||||
|
||||
@@ -19,12 +19,10 @@ class NetworkTechnology(NetworkTechnologyBase):
|
||||
"""Direct connections are always 'running' as they don't require a daemon"""
|
||||
return True
|
||||
|
||||
def ping(self, peer: Peer) -> None | float:
|
||||
def ping(self, remote: Remote) -> None | float:
|
||||
if self.is_running():
|
||||
try:
|
||||
# Parse the peer's host address to create a Remote object, use peer here since we don't have the machine_name here
|
||||
remote = Remote.from_ssh_uri(machine_name="peer", address=peer.host)
|
||||
|
||||
# Use the existing SSH reachability check
|
||||
now = time.time()
|
||||
|
||||
@@ -33,7 +31,7 @@ class NetworkTechnology(NetworkTechnologyBase):
|
||||
return (time.time() - now) * 1000
|
||||
|
||||
except ClanError as e:
|
||||
log.debug(f"Error checking peer {peer.host}: {e}")
|
||||
log.debug(f"Error checking peer {remote}: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
@@ -12,10 +12,10 @@ from clan_cli.vars.get import get_machine_var
|
||||
from clan_lib.errors import ClanError
|
||||
from clan_lib.flake import Flake
|
||||
from clan_lib.import_utils import ClassSource, import_with_source
|
||||
from clan_lib.ssh.remote import Remote
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from clan_lib.machines.machines import Machine
|
||||
from clan_lib.ssh.remote import Remote
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,7 +52,7 @@ class Peer:
|
||||
.lstrip("\n")
|
||||
)
|
||||
raise ClanError(msg)
|
||||
return var.value.decode()
|
||||
return var.value.decode().strip()
|
||||
msg = f"Unknown Var Type {self._host}"
|
||||
raise ClanError(msg)
|
||||
|
||||
@@ -76,7 +76,7 @@ class Network:
|
||||
return self.module.is_running()
|
||||
|
||||
def ping(self, peer: str) -> float | None:
|
||||
return self.module.ping(self.peers[peer])
|
||||
return self.module.ping(self.remote(peer))
|
||||
|
||||
def remote(self, peer: str) -> "Remote":
|
||||
# TODO raise exception if peer is not in peers
|
||||
@@ -96,7 +96,7 @@ class NetworkTechnologyBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def ping(self, peer: Peer) -> None | float:
|
||||
def ping(self, remote: "Remote") -> None | float:
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@@ -109,12 +109,18 @@ def networks_from_flake(flake: Flake) -> dict[str, Network]:
|
||||
# TODO more precaching, for example for vars
|
||||
flake.precache(
|
||||
[
|
||||
"clan.exports.instances.*.networking",
|
||||
"clan.?exports.instances.*.networking",
|
||||
]
|
||||
)
|
||||
networks: dict[str, Network] = {}
|
||||
networks_ = flake.select("clan.exports.instances.*.networking")
|
||||
for network_name, network in networks_.items():
|
||||
networks_ = flake.select("clan.?exports.instances.*.networking")
|
||||
if "exports" not in networks_:
|
||||
msg = """You are not exporting the clan exports through your flake.
|
||||
Please add exports next to clanInternals and nixosConfiguration into the global flake.
|
||||
"""
|
||||
log.warning(msg)
|
||||
return {}
|
||||
for network_name, network in networks_["exports"].items():
|
||||
if network:
|
||||
peers: dict[str, Peer] = {}
|
||||
for _peer in network["peers"].values():
|
||||
@@ -129,19 +135,8 @@ def networks_from_flake(flake: Flake) -> dict[str, Network]:
|
||||
return networks
|
||||
|
||||
|
||||
def get_best_network(machine_name: str, networks: dict[str, Network]) -> Network | None:
|
||||
for network_name, network in sorted(
|
||||
networks.items(), key=lambda network: -network[1].priority
|
||||
):
|
||||
if machine_name in network.peers:
|
||||
if network.is_running() and network.ping(machine_name):
|
||||
print(f"connecting via {network_name}")
|
||||
return network
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_remote_for_machine(machine: "Machine") -> Iterator["Remote"]:
|
||||
def get_best_remote(machine: "Machine") -> Iterator["Remote"]:
|
||||
"""
|
||||
Context manager that yields the best remote connection for a machine following this priority:
|
||||
1. If machine has targetHost in inventory, return a direct connection
|
||||
@@ -158,9 +153,6 @@ def get_remote_for_machine(machine: "Machine") -> Iterator["Remote"]:
|
||||
ClanError: If no connection method works
|
||||
"""
|
||||
|
||||
# Get networks from the flake
|
||||
networks = networks_from_flake(machine.flake)
|
||||
|
||||
# Step 1: Check if targetHost is set in inventory
|
||||
inv_machine = machine.get_inv_machine()
|
||||
target_host = inv_machine.get("deploy", {}).get("targetHost")
|
||||
@@ -176,39 +168,45 @@ def get_remote_for_machine(machine: "Machine") -> Iterator["Remote"]:
|
||||
log.debug(f"Inventory targetHost not reachable for {machine.name}: {e}")
|
||||
|
||||
# Step 2: Try existing networks by priority
|
||||
sorted_networks = sorted(networks.items(), key=lambda x: -x[1].priority)
|
||||
try:
|
||||
networks = networks_from_flake(machine.flake)
|
||||
|
||||
for network_name, network in sorted_networks:
|
||||
if machine.name not in network.peers:
|
||||
continue
|
||||
sorted_networks = sorted(networks.items(), key=lambda x: -x[1].priority)
|
||||
|
||||
# Check if network is running and machine is reachable
|
||||
if network.is_running():
|
||||
try:
|
||||
ping_time = network.ping(machine.name)
|
||||
if ping_time is not None:
|
||||
log.info(
|
||||
f"Machine {machine.name} reachable via {network_name} network"
|
||||
)
|
||||
yield network.remote(machine.name)
|
||||
return
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to reach {machine.name} via {network_name}: {e}")
|
||||
else:
|
||||
try:
|
||||
log.debug(f"Establishing connection for network {network_name}")
|
||||
with network.module.connection(network) as connected_network:
|
||||
ping_time = connected_network.ping(machine.name)
|
||||
for network_name, network in sorted_networks:
|
||||
if machine.name not in network.peers:
|
||||
continue
|
||||
|
||||
# Check if network is running and machine is reachable
|
||||
log.debug(f"trying to connect via {network_name}")
|
||||
if network.is_running():
|
||||
try:
|
||||
ping_time = network.ping(machine.name)
|
||||
if ping_time is not None:
|
||||
log.info(
|
||||
f"Machine {machine.name} reachable via {network_name} network after connection"
|
||||
f"Machine {machine.name} reachable via {network_name} network"
|
||||
)
|
||||
yield connected_network.remote(machine.name)
|
||||
yield network.remote(machine.name)
|
||||
return
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
f"Failed to establish connection to {machine.name} via {network_name}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to reach {machine.name} via {network_name}: {e}")
|
||||
else:
|
||||
try:
|
||||
log.debug(f"Establishing connection for network {network_name}")
|
||||
with network.module.connection(network) as connected_network:
|
||||
ping_time = connected_network.ping(machine.name)
|
||||
if ping_time is not None:
|
||||
log.info(
|
||||
f"Machine {machine.name} reachable via {network_name} network after connection"
|
||||
)
|
||||
yield connected_network.remote(machine.name)
|
||||
return
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
f"Failed to establish connection to {machine.name} via {network_name}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to use networking modules to determine machines remote: {e}")
|
||||
|
||||
# Step 3: Try targetHost from machine nixos config
|
||||
try:
|
||||
|
||||
@@ -26,46 +26,48 @@ def test_networks_from_flake(mock_get_machine_var: MagicMock) -> None:
|
||||
|
||||
# Define the expected return value from flake.select
|
||||
mock_networking_data = {
|
||||
"vpn-network": {
|
||||
"peers": {
|
||||
"machine1": {
|
||||
"name": "machine1",
|
||||
"host": {
|
||||
"var": {
|
||||
"machine": "machine1",
|
||||
"generator": "wireguard",
|
||||
"file": "address",
|
||||
}
|
||||
"exports": {
|
||||
"vpn-network": {
|
||||
"peers": {
|
||||
"machine1": {
|
||||
"name": "machine1",
|
||||
"host": {
|
||||
"var": {
|
||||
"machine": "machine1",
|
||||
"generator": "wireguard",
|
||||
"file": "address",
|
||||
}
|
||||
},
|
||||
},
|
||||
"machine2": {
|
||||
"name": "machine2",
|
||||
"host": {
|
||||
"var": {
|
||||
"machine": "machine2",
|
||||
"generator": "wireguard",
|
||||
"file": "address",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"machine2": {
|
||||
"name": "machine2",
|
||||
"host": {
|
||||
"var": {
|
||||
"machine": "machine2",
|
||||
"generator": "wireguard",
|
||||
"file": "address",
|
||||
}
|
||||
"module": "clan_lib.network.tor",
|
||||
"priority": 1000,
|
||||
},
|
||||
"local-network": {
|
||||
"peers": {
|
||||
"machine1": {
|
||||
"name": "machine1",
|
||||
"host": {"plain": "10.0.0.10"},
|
||||
},
|
||||
"machine3": {
|
||||
"name": "machine3",
|
||||
"host": {"plain": "10.0.0.12"},
|
||||
},
|
||||
},
|
||||
"module": "clan_lib.network.direct",
|
||||
"priority": 500,
|
||||
},
|
||||
"module": "clan_lib.network.tor",
|
||||
"priority": 1000,
|
||||
},
|
||||
"local-network": {
|
||||
"peers": {
|
||||
"machine1": {
|
||||
"name": "machine1",
|
||||
"host": {"plain": "10.0.0.10"},
|
||||
},
|
||||
"machine3": {
|
||||
"name": "machine3",
|
||||
"host": {"plain": "10.0.0.12"},
|
||||
},
|
||||
},
|
||||
"module": "clan_lib.network.direct",
|
||||
"priority": 500,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# Mock the select method
|
||||
@@ -75,7 +77,7 @@ def test_networks_from_flake(mock_get_machine_var: MagicMock) -> None:
|
||||
networks = networks_from_flake(flake)
|
||||
|
||||
# Verify the flake.select was called with the correct pattern
|
||||
flake.select.assert_called_once_with("clan.exports.instances.*.networking")
|
||||
flake.select.assert_called_once_with("clan.?exports.instances.*.networking")
|
||||
|
||||
# Verify the returned networks
|
||||
assert len(networks) == 2
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -13,9 +16,33 @@ from clan_lib.ssh.remote import Remote
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_qr_json_to_networks(
|
||||
qr_data: dict[str, Any], flake: Flake
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
@dataclass(frozen=True)
|
||||
class RemoteWithNetwork:
|
||||
network: Network
|
||||
remote: Remote
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QRCodeData:
|
||||
addresses: list[RemoteWithNetwork]
|
||||
|
||||
@contextmanager
|
||||
def get_best_remote(self) -> Iterator[Remote]:
|
||||
for address in self.addresses:
|
||||
try:
|
||||
log.debug(f"Establishing connection via {address}")
|
||||
with address.network.module.connection(
|
||||
address.network
|
||||
) as connected_network:
|
||||
ping_time = connected_network.module.ping(address.remote)
|
||||
if ping_time is not None:
|
||||
log.info(f"reachable via {address} after connection")
|
||||
yield address.remote
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to establish connection via {address}: {e}")
|
||||
|
||||
|
||||
def read_qr_json(qr_data: dict[str, Any], flake: Flake) -> QRCodeData:
|
||||
"""
|
||||
Parse QR code JSON contents and output a dict of networks with remotes.
|
||||
|
||||
@@ -45,32 +72,31 @@ def parse_qr_json_to_networks(
|
||||
}
|
||||
}
|
||||
"""
|
||||
networks: dict[str, dict[str, Any]] = {}
|
||||
addresses: list[RemoteWithNetwork] = []
|
||||
|
||||
password = qr_data.get("pass")
|
||||
|
||||
# Process clearnet addresses
|
||||
clearnet_addrs = qr_data.get("addrs", [])
|
||||
if clearnet_addrs:
|
||||
# For now, just use the first address
|
||||
addr = clearnet_addrs[0]
|
||||
if isinstance(addr, str):
|
||||
peer = Peer(name="installer", _host={"plain": addr}, flake=flake)
|
||||
network = Network(
|
||||
peers={"installer": peer},
|
||||
module_name="clan_lib.network.direct",
|
||||
priority=1000,
|
||||
)
|
||||
# Create the remote with password
|
||||
remote = Remote.from_ssh_uri(
|
||||
machine_name="installer",
|
||||
address=addr,
|
||||
).override(password=password)
|
||||
for addr in clearnet_addrs:
|
||||
if isinstance(addr, str):
|
||||
peer = Peer(name="installer", _host={"plain": addr}, flake=flake)
|
||||
network = Network(
|
||||
peers={"installer": peer},
|
||||
module_name="clan_lib.network.direct",
|
||||
priority=1000,
|
||||
)
|
||||
# Create the remote with password
|
||||
remote = Remote.from_ssh_uri(
|
||||
machine_name="installer",
|
||||
address=addr,
|
||||
).override(password=password)
|
||||
|
||||
networks["direct"] = {"network": network, "remote": remote}
|
||||
else:
|
||||
msg = f"Invalid address format: {addr}"
|
||||
raise ClanError(msg)
|
||||
addresses.append(RemoteWithNetwork(network=network, remote=remote))
|
||||
else:
|
||||
msg = f"Invalid address format: {addr}"
|
||||
raise ClanError(msg)
|
||||
|
||||
# Process tor address
|
||||
if tor_addr := qr_data.get("tor"):
|
||||
@@ -86,12 +112,12 @@ def parse_qr_json_to_networks(
|
||||
address=tor_addr,
|
||||
).override(password=password, socks_port=9050, socks_wrapper=["torify"])
|
||||
|
||||
networks["tor"] = {"network": network, "remote": remote}
|
||||
addresses.append(RemoteWithNetwork(network=network, remote=remote))
|
||||
|
||||
return networks
|
||||
return QRCodeData(addresses=addresses)
|
||||
|
||||
|
||||
def parse_qr_image_to_json(image_path: Path) -> dict[str, Any]:
|
||||
def read_qr_image(image_path: Path) -> dict[str, Any]:
|
||||
"""
|
||||
Parse a QR code image and extract the JSON data.
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
|
||||
from clan_lib.errors import ClanError
|
||||
from clan_lib.network import Network, NetworkTechnologyBase, Peer
|
||||
from clan_lib.network.tor.lib import is_tor_running, spawn_tor
|
||||
from clan_lib.ssh.remote import Remote
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from clan_lib.ssh.remote import Remote
|
||||
@@ -27,11 +28,9 @@ class NetworkTechnology(NetworkTechnologyBase):
|
||||
"""Check if Tor is running by sending HTTP request to SOCKS port."""
|
||||
return is_tor_running(self.proxy)
|
||||
|
||||
def ping(self, peer: Peer) -> None | float:
|
||||
def ping(self, remote: Remote) -> None | float:
|
||||
if self.is_running():
|
||||
try:
|
||||
remote = self.remote(peer)
|
||||
|
||||
# Use the existing SSH reachability check
|
||||
now = time.time()
|
||||
remote.check_machine_ssh_reachable()
|
||||
@@ -39,7 +38,7 @@ class NetworkTechnology(NetworkTechnologyBase):
|
||||
return (time.time() - now) * 1000
|
||||
|
||||
except ClanError as e:
|
||||
log.debug(f"Error checking peer {peer.host}: {e}")
|
||||
log.debug(f"Error checking peer {remote}: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user