clan-cli network: refactor, use new networking in ssh and install commands
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import ExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import get_args
|
from typing import get_args
|
||||||
|
|
||||||
@@ -8,6 +10,7 @@ from clan_lib.errors import ClanError
|
|||||||
from clan_lib.flake import require_flake
|
from clan_lib.flake import require_flake
|
||||||
from clan_lib.machines.install import BuildOn, InstallOptions, run_machine_install
|
from clan_lib.machines.install import BuildOn, InstallOptions, run_machine_install
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
|
from clan_lib.network.qr_code import read_qr_image, read_qr_json
|
||||||
from clan_lib.ssh.host_key import HostKeyCheck
|
from clan_lib.ssh.host_key import HostKeyCheck
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
@@ -17,11 +20,6 @@ from clan_cli.completions import (
|
|||||||
complete_target_host,
|
complete_target_host,
|
||||||
)
|
)
|
||||||
from clan_cli.machines.hardware import HardwareConfig
|
from clan_cli.machines.hardware import HardwareConfig
|
||||||
from clan_cli.ssh.deploy_info import (
|
|
||||||
find_reachable_host,
|
|
||||||
get_tor_remote,
|
|
||||||
ssh_command_parse,
|
|
||||||
)
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -31,81 +29,71 @@ def install_command(args: argparse.Namespace) -> None:
|
|||||||
flake = require_flake(args.flake)
|
flake = require_flake(args.flake)
|
||||||
# Only if the caller did not specify a target_host via args.target_host
|
# Only if the caller did not specify a target_host via args.target_host
|
||||||
# Find a suitable target_host that is reachable
|
# Find a suitable target_host that is reachable
|
||||||
target_host_str = args.target_host
|
with ExitStack() as stack:
|
||||||
remotes: list[Remote] | None = (
|
remote: Remote
|
||||||
ssh_command_parse(args) if target_host_str is None else None
|
if args.target_host:
|
||||||
)
|
# TODO add network support here with either --network or some url magic
|
||||||
|
remote = Remote.from_ssh_uri(
|
||||||
use_tor = False
|
machine_name=args.machine, address=args.target_host
|
||||||
if remotes:
|
|
||||||
host = find_reachable_host(remotes)
|
|
||||||
if host is None or host.socks_port:
|
|
||||||
use_tor = True
|
|
||||||
tor_remote = get_tor_remote(remotes)
|
|
||||||
target_host_str = tor_remote.target
|
|
||||||
else:
|
|
||||||
target_host_str = host.target
|
|
||||||
|
|
||||||
if args.password:
|
|
||||||
password = args.password
|
|
||||||
elif remotes and remotes[0].password:
|
|
||||||
password = remotes[0].password
|
|
||||||
else:
|
|
||||||
password = None
|
|
||||||
|
|
||||||
machine = Machine(name=args.machine, flake=flake)
|
|
||||||
host_key_check = args.host_key_check
|
|
||||||
|
|
||||||
if target_host_str is not None:
|
|
||||||
target_host = Remote.from_ssh_uri(
|
|
||||||
machine_name=machine.name, address=target_host_str
|
|
||||||
).override(host_key_check=host_key_check)
|
|
||||||
else:
|
|
||||||
target_host = machine.target_host().override(host_key_check=host_key_check)
|
|
||||||
|
|
||||||
if args.identity_file:
|
|
||||||
target_host = target_host.override(private_key=args.identity_file)
|
|
||||||
|
|
||||||
if machine._class_ == "darwin":
|
|
||||||
msg = "Installing macOS machines is not yet supported"
|
|
||||||
raise ClanError(msg)
|
|
||||||
|
|
||||||
if not args.yes:
|
|
||||||
while True:
|
|
||||||
ask = (
|
|
||||||
input(f"Install {args.machine} to {target_host.target}? [y/N] ")
|
|
||||||
.strip()
|
|
||||||
.lower()
|
|
||||||
)
|
)
|
||||||
if ask == "y":
|
elif args.png:
|
||||||
break
|
data = read_qr_image(Path(args.png))
|
||||||
if ask == "n" or ask == "":
|
qr_code = read_qr_json(data, args.flake)
|
||||||
return None
|
remote = stack.enter_context(qr_code.get_best_remote())
|
||||||
print(f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no.")
|
elif args.json:
|
||||||
|
json_file = Path(args.json)
|
||||||
|
if json_file.is_file():
|
||||||
|
data = json.loads(json_file.read_text())
|
||||||
|
else:
|
||||||
|
data = json.loads(args.json)
|
||||||
|
|
||||||
if args.identity_file:
|
qr_code = read_qr_json(data, args.flake)
|
||||||
target_host = target_host.override(private_key=args.identity_file)
|
remote = stack.enter_context(qr_code.get_best_remote())
|
||||||
|
else:
|
||||||
|
msg = "No MACHINE, --json or --png data provided"
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
if password:
|
machine = Machine(name=args.machine, flake=flake)
|
||||||
target_host = target_host.override(password=password)
|
if args.host_key_check:
|
||||||
|
remote.override(host_key_check=args.host_key_check)
|
||||||
|
|
||||||
if use_tor:
|
if machine._class_ == "darwin":
|
||||||
target_host = target_host.override(
|
msg = "Installing macOS machines is not yet supported"
|
||||||
socks_port=9050, socks_wrapper=["torify"]
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
if not args.yes:
|
||||||
|
while True:
|
||||||
|
ask = (
|
||||||
|
input(f"Install {args.machine} to {remote.target}? [y/N] ")
|
||||||
|
.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
|
if ask == "y":
|
||||||
|
break
|
||||||
|
if ask == "n" or ask == "":
|
||||||
|
return None
|
||||||
|
print(
|
||||||
|
f"Invalid input '{ask}'. Please enter 'y' for yes or 'n' for no."
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.identity_file:
|
||||||
|
remote = remote.override(private_key=args.identity_file)
|
||||||
|
|
||||||
|
if args.password:
|
||||||
|
remote = remote.override(password=args.password)
|
||||||
|
|
||||||
|
return run_machine_install(
|
||||||
|
InstallOptions(
|
||||||
|
machine=machine,
|
||||||
|
kexec=args.kexec,
|
||||||
|
phases=args.phases,
|
||||||
|
debug=args.debug,
|
||||||
|
no_reboot=args.no_reboot,
|
||||||
|
build_on=args.build_on if args.build_on is not None else None,
|
||||||
|
update_hardware_config=HardwareConfig(args.update_hardware_config),
|
||||||
|
),
|
||||||
|
target_host=remote,
|
||||||
)
|
)
|
||||||
|
|
||||||
return run_machine_install(
|
|
||||||
InstallOptions(
|
|
||||||
machine=machine,
|
|
||||||
kexec=args.kexec,
|
|
||||||
phases=args.phases,
|
|
||||||
debug=args.debug,
|
|
||||||
no_reboot=args.no_reboot,
|
|
||||||
build_on=args.build_on if args.build_on is not None else None,
|
|
||||||
update_hardware_config=HardwareConfig(args.update_hardware_config),
|
|
||||||
),
|
|
||||||
target_host=target_host,
|
|
||||||
)
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.warning("Interrupted by user")
|
log.warning("Interrupted by user")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import contextlib
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import textwrap
|
from contextlib import ExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import get_args
|
from typing import get_args
|
||||||
|
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.network.qr_code import parse_qr_image_to_json, parse_qr_json_to_networks
|
from clan_lib.network.network import get_best_remote
|
||||||
from clan_lib.network.tor.lib import spawn_tor
|
from clan_lib.network.qr_code import read_qr_image, read_qr_json
|
||||||
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
from clan_lib.ssh.remote import HostKeyCheck, Remote
|
||||||
|
|
||||||
from clan_cli.completions import (
|
from clan_cli.completions import (
|
||||||
@@ -35,112 +34,42 @@ def get_tor_remote(remotes: list[Remote]) -> Remote:
|
|||||||
return tor_remotes[0]
|
return tor_remotes[0]
|
||||||
|
|
||||||
|
|
||||||
def find_reachable_host(remotes: list[Remote]) -> Remote | None:
|
|
||||||
# If we only have one address, we have no choice but to use it.
|
|
||||||
if len(remotes) == 1:
|
|
||||||
return remotes[0]
|
|
||||||
|
|
||||||
for remote in remotes:
|
|
||||||
with contextlib.suppress(ClanError):
|
|
||||||
remote.check_machine_ssh_reachable()
|
|
||||||
return remote
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def ssh_shell_from_remotes(
|
|
||||||
remotes: list[Remote], command: list[str] | None = None
|
|
||||||
) -> None:
|
|
||||||
if command and len(command) == 1 and command[0].count(" ") > 0:
|
|
||||||
msg = (
|
|
||||||
textwrap.dedent("""
|
|
||||||
It looks like you quoted the remote command.
|
|
||||||
The first argument should be the command to run, not a quoted string.
|
|
||||||
""")
|
|
||||||
.lstrip("\n")
|
|
||||||
.rstrip("\n")
|
|
||||||
)
|
|
||||||
raise ClanError(msg)
|
|
||||||
|
|
||||||
if host := find_reachable_host(remotes):
|
|
||||||
host.interactive_ssh(command)
|
|
||||||
return
|
|
||||||
|
|
||||||
log.info("Could not reach host via clearnet addresses")
|
|
||||||
log.info("Trying to reach host via tor")
|
|
||||||
|
|
||||||
tor_remotes = [r for r in remotes if r.socks_port]
|
|
||||||
if not tor_remotes:
|
|
||||||
msg = "No tor address provided, please provide a tor address."
|
|
||||||
raise ClanError(msg)
|
|
||||||
|
|
||||||
with spawn_tor():
|
|
||||||
for tor_remote in tor_remotes:
|
|
||||||
log.info(f"Trying to reach host via tor address: {tor_remote}")
|
|
||||||
|
|
||||||
with contextlib.suppress(ClanError):
|
|
||||||
tor_remote.check_machine_ssh_reachable()
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"Host reachable via tor address, starting interactive ssh session."
|
|
||||||
)
|
|
||||||
tor_remote.interactive_ssh(command)
|
|
||||||
return
|
|
||||||
|
|
||||||
log.error("Could not reach host via tor address.")
|
|
||||||
|
|
||||||
|
|
||||||
def ssh_command_parse(args: argparse.Namespace) -> list[Remote] | None:
|
|
||||||
host_key_check = args.host_key_check
|
|
||||||
remotes = None
|
|
||||||
|
|
||||||
if args.json:
|
|
||||||
json_file = Path(args.json)
|
|
||||||
if json_file.is_file():
|
|
||||||
data = json.loads(json_file.read_text())
|
|
||||||
else:
|
|
||||||
data = json.loads(args.json)
|
|
||||||
|
|
||||||
networks = parse_qr_json_to_networks(data, args.flake)
|
|
||||||
remotes = []
|
|
||||||
for _network_type, network_data in networks.items():
|
|
||||||
remote = network_data["remote"]
|
|
||||||
remotes.append(remote.override(host_key_check=host_key_check))
|
|
||||||
|
|
||||||
elif args.png:
|
|
||||||
data = parse_qr_image_to_json(Path(args.png))
|
|
||||||
networks = parse_qr_json_to_networks(data, args.flake)
|
|
||||||
remotes = []
|
|
||||||
for _network_type, network_data in networks.items():
|
|
||||||
remote = network_data["remote"]
|
|
||||||
remotes.append(remote.override(host_key_check=host_key_check))
|
|
||||||
|
|
||||||
elif hasattr(args, "machine") and args.machine:
|
|
||||||
machine = Machine(args.machine, args.flake)
|
|
||||||
target = machine.target_host().override(
|
|
||||||
command_prefix=machine.name, host_key_check=host_key_check
|
|
||||||
)
|
|
||||||
remotes = [target]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
ssh_options = None
|
|
||||||
if hasattr(args, "ssh_option") and args.ssh_option:
|
|
||||||
ssh_options = {}
|
|
||||||
for name, value in args.ssh_option:
|
|
||||||
ssh_options[name] = value
|
|
||||||
|
|
||||||
if ssh_options:
|
|
||||||
remotes = [remote.override(ssh_options=ssh_options) for remote in remotes]
|
|
||||||
|
|
||||||
return remotes
|
|
||||||
|
|
||||||
|
|
||||||
def ssh_command(args: argparse.Namespace) -> None:
|
def ssh_command(args: argparse.Namespace) -> None:
|
||||||
remotes = ssh_command_parse(args)
|
with ExitStack() as stack:
|
||||||
if not remotes:
|
remote: Remote
|
||||||
msg = "No MACHINE, --json or --png data provided"
|
if hasattr(args, "machine") and args.machine:
|
||||||
raise ClanError(msg)
|
machine = Machine(args.machine, args.flake)
|
||||||
ssh_shell_from_remotes(remotes, args.remote_command)
|
remote = stack.enter_context(get_best_remote(machine))
|
||||||
|
elif args.png:
|
||||||
|
data = read_qr_image(Path(args.png))
|
||||||
|
qr_code = read_qr_json(data, args.flake)
|
||||||
|
remote = stack.enter_context(qr_code.get_best_remote())
|
||||||
|
elif args.json:
|
||||||
|
json_file = Path(args.json)
|
||||||
|
if json_file.is_file():
|
||||||
|
data = json.loads(json_file.read_text())
|
||||||
|
else:
|
||||||
|
data = json.loads(args.json)
|
||||||
|
|
||||||
|
qr_code = read_qr_json(data, args.flake)
|
||||||
|
remote = stack.enter_context(qr_code.get_best_remote())
|
||||||
|
else:
|
||||||
|
msg = "No MACHINE, --json or --png data provided"
|
||||||
|
raise ClanError(msg)
|
||||||
|
|
||||||
|
# Convert ssh_option list to dictionary
|
||||||
|
ssh_options = {}
|
||||||
|
if args.ssh_option:
|
||||||
|
for name, value in args.ssh_option:
|
||||||
|
ssh_options[name] = value
|
||||||
|
|
||||||
|
remote = remote.override(
|
||||||
|
host_key_check=args.host_key_check, ssh_options=ssh_options
|
||||||
|
)
|
||||||
|
if args.remote_command:
|
||||||
|
remote.interactive_ssh(args.remote_command)
|
||||||
|
else:
|
||||||
|
remote.interactive_ssh()
|
||||||
|
|
||||||
|
|
||||||
def register_parser(parser: argparse.ArgumentParser) -> None:
|
def register_parser(parser: argparse.ArgumentParser) -> None:
|
||||||
|
|||||||
@@ -4,11 +4,10 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
from clan_lib.cmd import RunOpts, run
|
from clan_lib.cmd import RunOpts, run
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.network.qr_code import parse_qr_image_to_json, parse_qr_json_to_networks
|
from clan_lib.network.qr_code import read_qr_image, read_qr_json
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
from clan_cli.ssh.deploy_info import find_reachable_host
|
|
||||||
from clan_cli.tests.fixtures_flakes import ClanFlake
|
from clan_cli.tests.fixtures_flakes import ClanFlake
|
||||||
from clan_cli.tests.helpers import cli
|
from clan_cli.tests.helpers import cli
|
||||||
|
|
||||||
@@ -28,84 +27,93 @@ def test_qrcode_scan(temp_dir: Path, flake: ClanFlake) -> None:
|
|||||||
run(cmd, RunOpts(input=data.encode()))
|
run(cmd, RunOpts(input=data.encode()))
|
||||||
|
|
||||||
# Call the qrcode_scan function
|
# Call the qrcode_scan function
|
||||||
json_data = parse_qr_image_to_json(img_path)
|
json_data = read_qr_image(img_path)
|
||||||
networks = parse_qr_json_to_networks(json_data, Flake(str(flake.path)))
|
qr_code = read_qr_json(json_data, Flake(str(flake.path)))
|
||||||
|
|
||||||
# Get direct network data
|
# Check addresses
|
||||||
direct_data = networks.get("direct")
|
addresses = qr_code.addresses
|
||||||
assert direct_data is not None
|
assert len(addresses) >= 2 # At least direct and tor
|
||||||
assert "network" in direct_data
|
|
||||||
assert "remote" in direct_data
|
|
||||||
|
|
||||||
# Get the remote
|
# Find direct connection
|
||||||
host = direct_data["remote"]
|
direct_remote = None
|
||||||
assert host.address == "192.168.122.86"
|
for addr in addresses:
|
||||||
assert host.user == "root"
|
if addr.network.module_name == "clan_lib.network.direct":
|
||||||
assert host.password == "scabbed-defender-headlock"
|
direct_remote = addr.remote
|
||||||
|
break
|
||||||
|
|
||||||
# Get tor network data
|
assert direct_remote is not None
|
||||||
tor_data = networks.get("tor")
|
assert direct_remote.address == "192.168.122.86"
|
||||||
assert tor_data is not None
|
assert direct_remote.user == "root"
|
||||||
assert "network" in tor_data
|
assert direct_remote.password == "scabbed-defender-headlock"
|
||||||
assert "remote" in tor_data
|
|
||||||
|
|
||||||
# Get the remote
|
# Find tor connection
|
||||||
tor_host = tor_data["remote"]
|
tor_remote = None
|
||||||
|
for addr in addresses:
|
||||||
|
if addr.network.module_name == "clan_lib.network.tor":
|
||||||
|
tor_remote = addr.remote
|
||||||
|
break
|
||||||
|
|
||||||
|
assert tor_remote is not None
|
||||||
assert (
|
assert (
|
||||||
tor_host.address
|
tor_remote.address
|
||||||
== "qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion"
|
== "qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion"
|
||||||
)
|
)
|
||||||
assert tor_host.socks_port == 9050
|
assert tor_remote.socks_port == 9050
|
||||||
assert tor_host.password == "scabbed-defender-headlock"
|
assert tor_remote.password == "scabbed-defender-headlock"
|
||||||
assert tor_host.user == "root"
|
assert tor_remote.user == "root"
|
||||||
|
|
||||||
|
|
||||||
def test_from_json(temp_dir: Path) -> None:
|
def test_from_json(temp_dir: Path) -> None:
|
||||||
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
|
data = '{"pass":"scabbed-defender-headlock","tor":"qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion","addrs":["192.168.122.86"]}'
|
||||||
flake = Flake(str(temp_dir))
|
flake = Flake(str(temp_dir))
|
||||||
networks = parse_qr_json_to_networks(json.loads(data), flake)
|
qr_code = read_qr_json(json.loads(data), flake)
|
||||||
|
|
||||||
# Get direct network data
|
# Check addresses
|
||||||
direct_data = networks.get("direct")
|
addresses = qr_code.addresses
|
||||||
assert direct_data is not None
|
assert len(addresses) >= 2 # At least direct and tor
|
||||||
assert "network" in direct_data
|
|
||||||
assert "remote" in direct_data
|
|
||||||
|
|
||||||
# Get the remote
|
# Find direct connection
|
||||||
host = direct_data["remote"]
|
direct_remote = None
|
||||||
assert host.password == "scabbed-defender-headlock"
|
for addr in addresses:
|
||||||
assert host.address == "192.168.122.86"
|
if addr.network.module_name == "clan_lib.network.direct":
|
||||||
|
direct_remote = addr.remote
|
||||||
|
break
|
||||||
|
|
||||||
# Get tor network data
|
assert direct_remote is not None
|
||||||
tor_data = networks.get("tor")
|
assert direct_remote.password == "scabbed-defender-headlock"
|
||||||
assert tor_data is not None
|
assert direct_remote.address == "192.168.122.86"
|
||||||
assert "network" in tor_data
|
|
||||||
assert "remote" in tor_data
|
|
||||||
|
|
||||||
# Get the remote
|
# Find tor connection
|
||||||
tor_host = tor_data["remote"]
|
tor_remote = None
|
||||||
|
for addr in addresses:
|
||||||
|
if addr.network.module_name == "clan_lib.network.tor":
|
||||||
|
tor_remote = addr.remote
|
||||||
|
break
|
||||||
|
|
||||||
|
assert tor_remote is not None
|
||||||
assert (
|
assert (
|
||||||
tor_host.address
|
tor_remote.address
|
||||||
== "qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion"
|
== "qjeerm4r6t55hcfum4pinnvscn5njlw2g3k7ilqfuu7cdt3ahaxhsbid.onion"
|
||||||
)
|
)
|
||||||
assert tor_host.socks_port == 9050
|
assert tor_remote.socks_port == 9050
|
||||||
assert tor_host.password == "scabbed-defender-headlock"
|
assert tor_remote.password == "scabbed-defender-headlock"
|
||||||
assert tor_host.user == "root"
|
assert tor_remote.user == "root"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.with_core
|
# TODO: This test needs to be updated to use get_best_remote from clan_lib.network.network
|
||||||
def test_find_reachable_host(hosts: list[Remote]) -> None:
|
# @pytest.mark.with_core
|
||||||
host = hosts[0]
|
# def test_find_reachable_host(hosts: list[Remote]) -> None:
|
||||||
|
# host = hosts[0]
|
||||||
uris = ["172.19.1.2", host.ssh_url()]
|
#
|
||||||
remotes = [Remote.from_ssh_uri(machine_name="some", address=uri) for uri in uris]
|
# uris = ["172.19.1.2", host.ssh_url()]
|
||||||
|
# remotes = [Remote.from_ssh_uri(machine_name="some", address=uri) for uri in uris]
|
||||||
assert remotes[0].address == "172.19.1.2"
|
#
|
||||||
|
# assert remotes[0].address == "172.19.1.2"
|
||||||
remote = find_reachable_host(remotes=remotes)
|
#
|
||||||
|
# remote = find_reachable_host(remotes=remotes)
|
||||||
assert remote is not None
|
#
|
||||||
assert remote.ssh_url() == host.ssh_url()
|
# assert remote is not None
|
||||||
|
# assert remote.ssh_url() == host.ssh_url()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.with_core
|
@pytest.mark.with_core
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from clan_cli.facts import secret_modules as facts_secret_modules
|
|||||||
from clan_cli.vars._types import StoreBase
|
from clan_cli.vars._types import StoreBase
|
||||||
|
|
||||||
from clan_lib.api import API
|
from clan_lib.api import API
|
||||||
from clan_lib.errors import ClanError
|
|
||||||
from clan_lib.flake import ClanSelectError, Flake
|
from clan_lib.flake import ClanSelectError, Flake
|
||||||
from clan_lib.nix_models.clan import InventoryMachine
|
from clan_lib.nix_models.clan import InventoryMachine
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
@@ -125,15 +124,10 @@ class Machine:
|
|||||||
return self.flake.path
|
return self.flake.path
|
||||||
|
|
||||||
def target_host(self) -> Remote:
|
def target_host(self) -> Remote:
|
||||||
remote = get_machine_host(self.name, self.flake, field="targetHost")
|
from clan_lib.network.network import get_best_remote
|
||||||
if remote is None:
|
|
||||||
msg = f"'targetHost' is not set for machine '{self.name}'"
|
with get_best_remote(self) as remote:
|
||||||
raise ClanError(
|
return remote
|
||||||
msg,
|
|
||||||
description="See https://docs.clan.lol/guides/getting-started/update/#setting-the-target-host for more information.",
|
|
||||||
)
|
|
||||||
data = remote.data
|
|
||||||
return data
|
|
||||||
|
|
||||||
def build_host(self) -> Remote | None:
|
def build_host(self) -> Remote | None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -19,12 +19,10 @@ class NetworkTechnology(NetworkTechnologyBase):
|
|||||||
"""Direct connections are always 'running' as they don't require a daemon"""
|
"""Direct connections are always 'running' as they don't require a daemon"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def ping(self, peer: Peer) -> None | float:
|
def ping(self, remote: Remote) -> None | float:
|
||||||
if self.is_running():
|
if self.is_running():
|
||||||
try:
|
try:
|
||||||
# Parse the peer's host address to create a Remote object, use peer here since we don't have the machine_name here
|
# Parse the peer's host address to create a Remote object, use peer here since we don't have the machine_name here
|
||||||
remote = Remote.from_ssh_uri(machine_name="peer", address=peer.host)
|
|
||||||
|
|
||||||
# Use the existing SSH reachability check
|
# Use the existing SSH reachability check
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
@@ -33,7 +31,7 @@ class NetworkTechnology(NetworkTechnologyBase):
|
|||||||
return (time.time() - now) * 1000
|
return (time.time() - now) * 1000
|
||||||
|
|
||||||
except ClanError as e:
|
except ClanError as e:
|
||||||
log.debug(f"Error checking peer {peer.host}: {e}")
|
log.debug(f"Error checking peer {remote}: {e}")
|
||||||
return None
|
return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ from clan_cli.vars.get import get_machine_var
|
|||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.import_utils import ClassSource, import_with_source
|
from clan_lib.import_utils import ClassSource, import_with_source
|
||||||
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.ssh.remote import Remote
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ class Peer:
|
|||||||
.lstrip("\n")
|
.lstrip("\n")
|
||||||
)
|
)
|
||||||
raise ClanError(msg)
|
raise ClanError(msg)
|
||||||
return var.value.decode()
|
return var.value.decode().strip()
|
||||||
msg = f"Unknown Var Type {self._host}"
|
msg = f"Unknown Var Type {self._host}"
|
||||||
raise ClanError(msg)
|
raise ClanError(msg)
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ class Network:
|
|||||||
return self.module.is_running()
|
return self.module.is_running()
|
||||||
|
|
||||||
def ping(self, peer: str) -> float | None:
|
def ping(self, peer: str) -> float | None:
|
||||||
return self.module.ping(self.peers[peer])
|
return self.module.ping(self.remote(peer))
|
||||||
|
|
||||||
def remote(self, peer: str) -> "Remote":
|
def remote(self, peer: str) -> "Remote":
|
||||||
# TODO raise exception if peer is not in peers
|
# TODO raise exception if peer is not in peers
|
||||||
@@ -96,7 +96,7 @@ class NetworkTechnologyBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def ping(self, peer: Peer) -> None | float:
|
def ping(self, remote: "Remote") -> None | float:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -109,12 +109,18 @@ def networks_from_flake(flake: Flake) -> dict[str, Network]:
|
|||||||
# TODO more precaching, for example for vars
|
# TODO more precaching, for example for vars
|
||||||
flake.precache(
|
flake.precache(
|
||||||
[
|
[
|
||||||
"clan.exports.instances.*.networking",
|
"clan.?exports.instances.*.networking",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
networks: dict[str, Network] = {}
|
networks: dict[str, Network] = {}
|
||||||
networks_ = flake.select("clan.exports.instances.*.networking")
|
networks_ = flake.select("clan.?exports.instances.*.networking")
|
||||||
for network_name, network in networks_.items():
|
if "exports" not in networks_:
|
||||||
|
msg = """You are not exporting the clan exports through your flake.
|
||||||
|
Please add exports next to clanInternals and nixosConfiguration into the global flake.
|
||||||
|
"""
|
||||||
|
log.warning(msg)
|
||||||
|
return {}
|
||||||
|
for network_name, network in networks_["exports"].items():
|
||||||
if network:
|
if network:
|
||||||
peers: dict[str, Peer] = {}
|
peers: dict[str, Peer] = {}
|
||||||
for _peer in network["peers"].values():
|
for _peer in network["peers"].values():
|
||||||
@@ -129,19 +135,8 @@ def networks_from_flake(flake: Flake) -> dict[str, Network]:
|
|||||||
return networks
|
return networks
|
||||||
|
|
||||||
|
|
||||||
def get_best_network(machine_name: str, networks: dict[str, Network]) -> Network | None:
|
|
||||||
for network_name, network in sorted(
|
|
||||||
networks.items(), key=lambda network: -network[1].priority
|
|
||||||
):
|
|
||||||
if machine_name in network.peers:
|
|
||||||
if network.is_running() and network.ping(machine_name):
|
|
||||||
print(f"connecting via {network_name}")
|
|
||||||
return network
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_remote_for_machine(machine: "Machine") -> Iterator["Remote"]:
|
def get_best_remote(machine: "Machine") -> Iterator["Remote"]:
|
||||||
"""
|
"""
|
||||||
Context manager that yields the best remote connection for a machine following this priority:
|
Context manager that yields the best remote connection for a machine following this priority:
|
||||||
1. If machine has targetHost in inventory, return a direct connection
|
1. If machine has targetHost in inventory, return a direct connection
|
||||||
@@ -158,9 +153,6 @@ def get_remote_for_machine(machine: "Machine") -> Iterator["Remote"]:
|
|||||||
ClanError: If no connection method works
|
ClanError: If no connection method works
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get networks from the flake
|
|
||||||
networks = networks_from_flake(machine.flake)
|
|
||||||
|
|
||||||
# Step 1: Check if targetHost is set in inventory
|
# Step 1: Check if targetHost is set in inventory
|
||||||
inv_machine = machine.get_inv_machine()
|
inv_machine = machine.get_inv_machine()
|
||||||
target_host = inv_machine.get("deploy", {}).get("targetHost")
|
target_host = inv_machine.get("deploy", {}).get("targetHost")
|
||||||
@@ -176,39 +168,45 @@ def get_remote_for_machine(machine: "Machine") -> Iterator["Remote"]:
|
|||||||
log.debug(f"Inventory targetHost not reachable for {machine.name}: {e}")
|
log.debug(f"Inventory targetHost not reachable for {machine.name}: {e}")
|
||||||
|
|
||||||
# Step 2: Try existing networks by priority
|
# Step 2: Try existing networks by priority
|
||||||
sorted_networks = sorted(networks.items(), key=lambda x: -x[1].priority)
|
try:
|
||||||
|
networks = networks_from_flake(machine.flake)
|
||||||
|
|
||||||
for network_name, network in sorted_networks:
|
sorted_networks = sorted(networks.items(), key=lambda x: -x[1].priority)
|
||||||
if machine.name not in network.peers:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if network is running and machine is reachable
|
for network_name, network in sorted_networks:
|
||||||
if network.is_running():
|
if machine.name not in network.peers:
|
||||||
try:
|
continue
|
||||||
ping_time = network.ping(machine.name)
|
|
||||||
if ping_time is not None:
|
# Check if network is running and machine is reachable
|
||||||
log.info(
|
log.debug(f"trying to connect via {network_name}")
|
||||||
f"Machine {machine.name} reachable via {network_name} network"
|
if network.is_running():
|
||||||
)
|
try:
|
||||||
yield network.remote(machine.name)
|
ping_time = network.ping(machine.name)
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
log.debug(f"Failed to reach {machine.name} via {network_name}: {e}")
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
log.debug(f"Establishing connection for network {network_name}")
|
|
||||||
with network.module.connection(network) as connected_network:
|
|
||||||
ping_time = connected_network.ping(machine.name)
|
|
||||||
if ping_time is not None:
|
if ping_time is not None:
|
||||||
log.info(
|
log.info(
|
||||||
f"Machine {machine.name} reachable via {network_name} network after connection"
|
f"Machine {machine.name} reachable via {network_name} network"
|
||||||
)
|
)
|
||||||
yield connected_network.remote(machine.name)
|
yield network.remote(machine.name)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(
|
log.debug(f"Failed to reach {machine.name} via {network_name}: {e}")
|
||||||
f"Failed to establish connection to {machine.name} via {network_name}: {e}"
|
else:
|
||||||
)
|
try:
|
||||||
|
log.debug(f"Establishing connection for network {network_name}")
|
||||||
|
with network.module.connection(network) as connected_network:
|
||||||
|
ping_time = connected_network.ping(machine.name)
|
||||||
|
if ping_time is not None:
|
||||||
|
log.info(
|
||||||
|
f"Machine {machine.name} reachable via {network_name} network after connection"
|
||||||
|
)
|
||||||
|
yield connected_network.remote(machine.name)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(
|
||||||
|
f"Failed to establish connection to {machine.name} via {network_name}: {e}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Failed to use networking modules to determine machines remote: {e}")
|
||||||
|
|
||||||
# Step 3: Try targetHost from machine nixos config
|
# Step 3: Try targetHost from machine nixos config
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -26,46 +26,48 @@ def test_networks_from_flake(mock_get_machine_var: MagicMock) -> None:
|
|||||||
|
|
||||||
# Define the expected return value from flake.select
|
# Define the expected return value from flake.select
|
||||||
mock_networking_data = {
|
mock_networking_data = {
|
||||||
"vpn-network": {
|
"exports": {
|
||||||
"peers": {
|
"vpn-network": {
|
||||||
"machine1": {
|
"peers": {
|
||||||
"name": "machine1",
|
"machine1": {
|
||||||
"host": {
|
"name": "machine1",
|
||||||
"var": {
|
"host": {
|
||||||
"machine": "machine1",
|
"var": {
|
||||||
"generator": "wireguard",
|
"machine": "machine1",
|
||||||
"file": "address",
|
"generator": "wireguard",
|
||||||
}
|
"file": "address",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"machine2": {
|
||||||
|
"name": "machine2",
|
||||||
|
"host": {
|
||||||
|
"var": {
|
||||||
|
"machine": "machine2",
|
||||||
|
"generator": "wireguard",
|
||||||
|
"file": "address",
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"machine2": {
|
"module": "clan_lib.network.tor",
|
||||||
"name": "machine2",
|
"priority": 1000,
|
||||||
"host": {
|
},
|
||||||
"var": {
|
"local-network": {
|
||||||
"machine": "machine2",
|
"peers": {
|
||||||
"generator": "wireguard",
|
"machine1": {
|
||||||
"file": "address",
|
"name": "machine1",
|
||||||
}
|
"host": {"plain": "10.0.0.10"},
|
||||||
|
},
|
||||||
|
"machine3": {
|
||||||
|
"name": "machine3",
|
||||||
|
"host": {"plain": "10.0.0.12"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"module": "clan_lib.network.direct",
|
||||||
|
"priority": 500,
|
||||||
},
|
},
|
||||||
"module": "clan_lib.network.tor",
|
}
|
||||||
"priority": 1000,
|
|
||||||
},
|
|
||||||
"local-network": {
|
|
||||||
"peers": {
|
|
||||||
"machine1": {
|
|
||||||
"name": "machine1",
|
|
||||||
"host": {"plain": "10.0.0.10"},
|
|
||||||
},
|
|
||||||
"machine3": {
|
|
||||||
"name": "machine3",
|
|
||||||
"host": {"plain": "10.0.0.12"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"module": "clan_lib.network.direct",
|
|
||||||
"priority": 500,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock the select method
|
# Mock the select method
|
||||||
@@ -75,7 +77,7 @@ def test_networks_from_flake(mock_get_machine_var: MagicMock) -> None:
|
|||||||
networks = networks_from_flake(flake)
|
networks = networks_from_flake(flake)
|
||||||
|
|
||||||
# Verify the flake.select was called with the correct pattern
|
# Verify the flake.select was called with the correct pattern
|
||||||
flake.select.assert_called_once_with("clan.exports.instances.*.networking")
|
flake.select.assert_called_once_with("clan.?exports.instances.*.networking")
|
||||||
|
|
||||||
# Verify the returned networks
|
# Verify the returned networks
|
||||||
assert len(networks) == 2
|
assert len(networks) == 2
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -13,9 +16,33 @@ from clan_lib.ssh.remote import Remote
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def parse_qr_json_to_networks(
|
@dataclass(frozen=True)
|
||||||
qr_data: dict[str, Any], flake: Flake
|
class RemoteWithNetwork:
|
||||||
) -> dict[str, dict[str, Any]]:
|
network: Network
|
||||||
|
remote: Remote
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class QRCodeData:
|
||||||
|
addresses: list[RemoteWithNetwork]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_best_remote(self) -> Iterator[Remote]:
|
||||||
|
for address in self.addresses:
|
||||||
|
try:
|
||||||
|
log.debug(f"Establishing connection via {address}")
|
||||||
|
with address.network.module.connection(
|
||||||
|
address.network
|
||||||
|
) as connected_network:
|
||||||
|
ping_time = connected_network.module.ping(address.remote)
|
||||||
|
if ping_time is not None:
|
||||||
|
log.info(f"reachable via {address} after connection")
|
||||||
|
yield address.remote
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Failed to establish connection via {address}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def read_qr_json(qr_data: dict[str, Any], flake: Flake) -> QRCodeData:
|
||||||
"""
|
"""
|
||||||
Parse QR code JSON contents and output a dict of networks with remotes.
|
Parse QR code JSON contents and output a dict of networks with remotes.
|
||||||
|
|
||||||
@@ -45,32 +72,31 @@ def parse_qr_json_to_networks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
networks: dict[str, dict[str, Any]] = {}
|
addresses: list[RemoteWithNetwork] = []
|
||||||
|
|
||||||
password = qr_data.get("pass")
|
password = qr_data.get("pass")
|
||||||
|
|
||||||
# Process clearnet addresses
|
# Process clearnet addresses
|
||||||
clearnet_addrs = qr_data.get("addrs", [])
|
clearnet_addrs = qr_data.get("addrs", [])
|
||||||
if clearnet_addrs:
|
if clearnet_addrs:
|
||||||
# For now, just use the first address
|
for addr in clearnet_addrs:
|
||||||
addr = clearnet_addrs[0]
|
if isinstance(addr, str):
|
||||||
if isinstance(addr, str):
|
peer = Peer(name="installer", _host={"plain": addr}, flake=flake)
|
||||||
peer = Peer(name="installer", _host={"plain": addr}, flake=flake)
|
network = Network(
|
||||||
network = Network(
|
peers={"installer": peer},
|
||||||
peers={"installer": peer},
|
module_name="clan_lib.network.direct",
|
||||||
module_name="clan_lib.network.direct",
|
priority=1000,
|
||||||
priority=1000,
|
)
|
||||||
)
|
# Create the remote with password
|
||||||
# Create the remote with password
|
remote = Remote.from_ssh_uri(
|
||||||
remote = Remote.from_ssh_uri(
|
machine_name="installer",
|
||||||
machine_name="installer",
|
address=addr,
|
||||||
address=addr,
|
).override(password=password)
|
||||||
).override(password=password)
|
|
||||||
|
|
||||||
networks["direct"] = {"network": network, "remote": remote}
|
addresses.append(RemoteWithNetwork(network=network, remote=remote))
|
||||||
else:
|
else:
|
||||||
msg = f"Invalid address format: {addr}"
|
msg = f"Invalid address format: {addr}"
|
||||||
raise ClanError(msg)
|
raise ClanError(msg)
|
||||||
|
|
||||||
# Process tor address
|
# Process tor address
|
||||||
if tor_addr := qr_data.get("tor"):
|
if tor_addr := qr_data.get("tor"):
|
||||||
@@ -86,12 +112,12 @@ def parse_qr_json_to_networks(
|
|||||||
address=tor_addr,
|
address=tor_addr,
|
||||||
).override(password=password, socks_port=9050, socks_wrapper=["torify"])
|
).override(password=password, socks_port=9050, socks_wrapper=["torify"])
|
||||||
|
|
||||||
networks["tor"] = {"network": network, "remote": remote}
|
addresses.append(RemoteWithNetwork(network=network, remote=remote))
|
||||||
|
|
||||||
return networks
|
return QRCodeData(addresses=addresses)
|
||||||
|
|
||||||
|
|
||||||
def parse_qr_image_to_json(image_path: Path) -> dict[str, Any]:
|
def read_qr_image(image_path: Path) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Parse a QR code image and extract the JSON data.
|
Parse a QR code image and extract the JSON data.
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING
|
|||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.network import Network, NetworkTechnologyBase, Peer
|
from clan_lib.network import Network, NetworkTechnologyBase, Peer
|
||||||
from clan_lib.network.tor.lib import is_tor_running, spawn_tor
|
from clan_lib.network.tor.lib import is_tor_running, spawn_tor
|
||||||
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
@@ -27,11 +28,9 @@ class NetworkTechnology(NetworkTechnologyBase):
|
|||||||
"""Check if Tor is running by sending HTTP request to SOCKS port."""
|
"""Check if Tor is running by sending HTTP request to SOCKS port."""
|
||||||
return is_tor_running(self.proxy)
|
return is_tor_running(self.proxy)
|
||||||
|
|
||||||
def ping(self, peer: Peer) -> None | float:
|
def ping(self, remote: Remote) -> None | float:
|
||||||
if self.is_running():
|
if self.is_running():
|
||||||
try:
|
try:
|
||||||
remote = self.remote(peer)
|
|
||||||
|
|
||||||
# Use the existing SSH reachability check
|
# Use the existing SSH reachability check
|
||||||
now = time.time()
|
now = time.time()
|
||||||
remote.check_machine_ssh_reachable()
|
remote.check_machine_ssh_reachable()
|
||||||
@@ -39,7 +38,7 @@ class NetworkTechnology(NetworkTechnologyBase):
|
|||||||
return (time.time() - now) * 1000
|
return (time.time() - now) * 1000
|
||||||
|
|
||||||
except ClanError as e:
|
except ClanError as e:
|
||||||
log.debug(f"Error checking peer {peer.host}: {e}")
|
log.debug(f"Error checking peer {remote}: {e}")
|
||||||
return None
|
return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user