ssh: refactor callers to use new Host interface
This commit is contained in:
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import clan_lib.machines.machines as machines
|
import clan_lib.machines.machines as machines
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
class SecretStoreBase(ABC):
|
class SecretStoreBase(ABC):
|
||||||
@@ -26,7 +26,7 @@ class SecretStoreBase(ABC):
|
|||||||
def exists(self, service: str, name: str) -> bool:
|
def exists(self, service: str, name: str) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def needs_upload(self, host: Remote) -> bool:
|
def needs_upload(self, host: Host) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import override
|
|||||||
from clan_lib.cmd import Log, RunOpts
|
from clan_lib.cmd import Log, RunOpts
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
from clan_cli.facts.secret_modules import SecretStoreBase
|
from clan_cli.facts.secret_modules import SecretStoreBase
|
||||||
|
|
||||||
@@ -94,9 +94,9 @@ class SecretStore(SecretStoreBase):
|
|||||||
return b"\n".join(hashes)
|
return b"\n".join(hashes)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def needs_upload(self, host: Remote) -> bool:
|
def needs_upload(self, host: Host) -> bool:
|
||||||
local_hash = self.generate_hash()
|
local_hash = self.generate_hash()
|
||||||
with host.ssh_control_master() as ssh:
|
with host.host_connection() as ssh:
|
||||||
remote_hash = ssh.run(
|
remote_hash = ssh.run(
|
||||||
# TODO get the path to the secrets from the machine
|
# TODO get the path to the secrets from the machine
|
||||||
["cat", f"{self.machine.secrets_upload_directory}/.pass_info"],
|
["cat", f"{self.machine.secrets_upload_directory}/.pass_info"],
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from pathlib import Path
|
|||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
from clan_cli.secrets.folders import sops_secrets_folder
|
from clan_cli.secrets.folders import sops_secrets_folder
|
||||||
from clan_cli.secrets.machines import add_machine, has_machine
|
from clan_cli.secrets.machines import add_machine, has_machine
|
||||||
@@ -64,7 +64,7 @@ class SecretStore(SecretStoreBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def needs_upload(self, host: Remote) -> bool:
|
def needs_upload(self, host: Host) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# We rely now on the vars backend to upload the age key
|
# We rely now on the vars backend to upload the age key
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
|
|||||||
|
|
||||||
from clan_lib.flake import require_flake
|
from clan_lib.flake import require_flake
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
from clan_cli.ssh.upload import upload
|
from clan_cli.ssh.upload import upload
|
||||||
@@ -13,7 +13,7 @@ from clan_cli.ssh.upload import upload
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def upload_secrets(machine: Machine, host: Remote) -> None:
|
def upload_secrets(machine: Machine, host: Host) -> None:
|
||||||
if not machine.secret_facts_store.needs_upload(host):
|
if not machine.secret_facts_store.needs_upload(host):
|
||||||
machine.info("Secrets already uploaded")
|
machine.info("Secrets already uploaded")
|
||||||
return
|
return
|
||||||
@@ -28,7 +28,7 @@ def upload_secrets(machine: Machine, host: Remote) -> None:
|
|||||||
def upload_command(args: argparse.Namespace) -> None:
|
def upload_command(args: argparse.Namespace) -> None:
|
||||||
flake = require_flake(args.flake)
|
flake = require_flake(args.flake)
|
||||||
machine = Machine(name=args.machine, flake=flake)
|
machine = Machine(name=args.machine, flake=flake)
|
||||||
with machine.target_host().ssh_control_master() as host:
|
with machine.target_host().host_connection() as host:
|
||||||
upload_secrets(machine, host)
|
upload_secrets(machine, host)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ from tempfile import TemporaryDirectory
|
|||||||
|
|
||||||
from clan_lib.cmd import Log, RunOpts
|
from clan_lib.cmd import Log, RunOpts
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
def upload(
|
def upload(
|
||||||
host: Remote,
|
host: Host,
|
||||||
local_src: Path,
|
local_src: Path,
|
||||||
remote_dest: Path, # must be a directory
|
remote_dest: Path, # must be a directory
|
||||||
file_user: str = "root",
|
file_user: str = "root",
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ def hosts(sshd: Sshd) -> list[Remote]:
|
|||||||
login = pwd.getpwuid(os.getuid()).pw_name
|
login = pwd.getpwuid(os.getuid()).pw_name
|
||||||
group = [
|
group = [
|
||||||
Remote(
|
Remote(
|
||||||
"127.0.0.1",
|
address="127.0.0.1",
|
||||||
port=sshd.port,
|
port=sshd.port,
|
||||||
user=login,
|
user=login,
|
||||||
private_key=Path(sshd.key),
|
private_key=Path(sshd.key),
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from clan_lib.async_run import AsyncRuntime
|
|||||||
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
|
from clan_lib.cmd import ClanCmdTimeoutError, Log, RunOpts
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
host = Remote("some_host", user="root", command_prefix="local_test")
|
host = Remote(address="some_host", user="root", command_prefix="local_test")
|
||||||
|
|
||||||
|
|
||||||
def test_run_environment(runtime: AsyncRuntime) -> None:
|
def test_run_environment(runtime: AsyncRuntime) -> None:
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def test_upload_single_file(
|
|||||||
src_file = temporary_home / "test.txt"
|
src_file = temporary_home / "test.txt"
|
||||||
src_file.write_text("test")
|
src_file.write_text("test")
|
||||||
dest_file = temporary_home / "test_dest.txt"
|
dest_file = temporary_home / "test_dest.txt"
|
||||||
with host.ssh_control_master() as host:
|
with host.host_connection() as host:
|
||||||
upload(host, src_file, dest_file)
|
upload(host, src_file, dest_file)
|
||||||
|
|
||||||
assert dest_file.exists()
|
assert dest_file.exists()
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .generate import Generator, Var
|
from .generate import Generator, Var
|
||||||
@@ -200,5 +200,5 @@ class StoreBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def upload(self, machine: str, host: Remote, phases: list[str]) -> None:
|
def upload(self, machine: str, host: Host, phases: list[str]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from clan_cli.vars._types import StoreBase
|
|||||||
from clan_cli.vars.generate import Generator, Var
|
from clan_cli.vars.generate import Generator, Var
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
class FactStore(StoreBase):
|
class FactStore(StoreBase):
|
||||||
@@ -73,6 +73,6 @@ class FactStore(StoreBase):
|
|||||||
msg = "populate_dir is not implemented for public vars stores"
|
msg = "populate_dir is not implemented for public vars stores"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
def upload(self, machine: str, host: Remote, phases: list[str]) -> None:
|
def upload(self, machine: str, host: Host, phases: list[str]) -> None:
|
||||||
msg = "upload is not implemented for public vars stores"
|
msg = "upload is not implemented for public vars stores"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from clan_cli.vars.generate import Generator, Var
|
|||||||
from clan_lib.dirs import vm_state_dir
|
from clan_lib.dirs import vm_state_dir
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -82,6 +82,6 @@ class FactStore(StoreBase):
|
|||||||
msg = "populate_dir is not implemented for public vars stores"
|
msg = "populate_dir is not implemented for public vars stores"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
def upload(self, machine: str, host: Remote, phases: list[str]) -> None:
|
def upload(self, machine: str, host: Host, phases: list[str]) -> None:
|
||||||
msg = "upload is not implemented for public vars stores"
|
msg = "upload is not implemented for public vars stores"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
|||||||
from clan_cli.vars._types import StoreBase
|
from clan_cli.vars._types import StoreBase
|
||||||
from clan_cli.vars.generate import Generator, Var
|
from clan_cli.vars.generate import Generator, Var
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
class SecretStore(StoreBase):
|
class SecretStore(StoreBase):
|
||||||
@@ -57,6 +57,6 @@ class SecretStore(StoreBase):
|
|||||||
shutil.rmtree(self.dir)
|
shutil.rmtree(self.dir)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def upload(self, machine: str, host: Remote, phases: list[str]) -> None:
|
def upload(self, machine: str, host: Host, phases: list[str]) -> None:
|
||||||
msg = "Cannot upload secrets with FS backend"
|
msg = "Cannot upload secrets with FS backend"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from clan_cli.ssh.upload import upload
|
|||||||
from clan_cli.vars._types import StoreBase
|
from clan_cli.vars._types import StoreBase
|
||||||
from clan_cli.vars.generate import Generator, Var
|
from clan_cli.vars.generate import Generator, Var
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -157,7 +157,7 @@ class SecretStore(StoreBase):
|
|||||||
manifest.append(git_hash)
|
manifest.append(git_hash)
|
||||||
return b"\n".join(manifest)
|
return b"\n".join(manifest)
|
||||||
|
|
||||||
def needs_upload(self, machine: str, host: Remote) -> bool:
|
def needs_upload(self, machine: str, host: Host) -> bool:
|
||||||
local_hash = self.generate_hash(machine)
|
local_hash = self.generate_hash(machine)
|
||||||
if not local_hash:
|
if not local_hash:
|
||||||
return True
|
return True
|
||||||
@@ -243,7 +243,7 @@ class SecretStore(StoreBase):
|
|||||||
if hash_data:
|
if hash_data:
|
||||||
(output_dir / ".pass_info").write_bytes(hash_data)
|
(output_dir / ".pass_info").write_bytes(hash_data)
|
||||||
|
|
||||||
def upload(self, machine: str, host: Remote, phases: list[str]) -> None:
|
def upload(self, machine: str, host: Host, phases: list[str]) -> None:
|
||||||
if "partitioning" in phases:
|
if "partitioning" in phases:
|
||||||
msg = "Cannot upload partitioning secrets"
|
msg = "Cannot upload partitioning secrets"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from clan_cli.vars.generate import Generator
|
|||||||
from clan_cli.vars.var import Var
|
from clan_cli.vars.var import Var
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -246,7 +246,7 @@ class SecretStore(StoreBase):
|
|||||||
target_path.chmod(file.mode)
|
target_path.chmod(file.mode)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def upload(self, machine: str, host: Remote, phases: list[str]) -> None:
|
def upload(self, machine: str, host: Host, phases: list[str]) -> None:
|
||||||
if "partitioning" in phases:
|
if "partitioning" in phases:
|
||||||
msg = "Cannot upload partitioning secrets"
|
msg = "Cannot upload partitioning secrets"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from clan_cli.vars._types import StoreBase
|
|||||||
from clan_cli.vars.generate import Generator, Var
|
from clan_cli.vars.generate import Generator, Var
|
||||||
from clan_lib.dirs import vm_state_dir
|
from clan_lib.dirs import vm_state_dir
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
|
|
||||||
class SecretStore(StoreBase):
|
class SecretStore(StoreBase):
|
||||||
@@ -71,6 +71,6 @@ class SecretStore(StoreBase):
|
|||||||
shutil.rmtree(output_dir)
|
shutil.rmtree(output_dir)
|
||||||
shutil.copytree(vars_dir, output_dir)
|
shutil.copytree(vars_dir, output_dir)
|
||||||
|
|
||||||
def upload(self, machine: str, host: Remote, phases: list[str]) -> None:
|
def upload(self, machine: str, host: Host, phases: list[str]) -> None:
|
||||||
msg = "Cannot upload secrets to VMs"
|
msg = "Cannot upload secrets to VMs"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ from pathlib import Path
|
|||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
from clan_lib.flake import require_flake
|
from clan_lib.flake import require_flake
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.host import Host
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def upload_secret_vars(machine: Machine, host: Remote) -> None:
|
def upload_secret_vars(machine: Machine, host: Host) -> None:
|
||||||
machine.secret_vars_store.upload(
|
machine.secret_vars_store.upload(
|
||||||
machine.name, host, phases=["activation", "users", "services"]
|
machine.name, host, phases=["activation", "users", "services"]
|
||||||
)
|
)
|
||||||
@@ -32,7 +32,7 @@ def upload_command(args: argparse.Namespace) -> None:
|
|||||||
populate_secret_vars(machine, directory)
|
populate_secret_vars(machine, directory)
|
||||||
return
|
return
|
||||||
|
|
||||||
with machine.target_host().ssh_control_master() as host, host.become_root() as host:
|
with machine.target_host().host_connection() as host, host.become_root() as host:
|
||||||
upload_secret_vars(machine, host)
|
upload_secret_vars(machine, host)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ def create_backup(machine: Machine, provider: str | None = None) -> None:
|
|||||||
if not backup_scripts["providers"]:
|
if not backup_scripts["providers"]:
|
||||||
msg = "No providers specified"
|
msg = "No providers specified"
|
||||||
raise ClanError(msg)
|
raise ClanError(msg)
|
||||||
with host.ssh_control_master() as ssh:
|
with host.host_connection() as ssh:
|
||||||
for provider in backup_scripts["providers"]:
|
for provider in backup_scripts["providers"]:
|
||||||
proc = ssh.run(
|
proc = ssh.run(
|
||||||
[backup_scripts["providers"][provider]["create"]],
|
[backup_scripts["providers"][provider]["create"]],
|
||||||
@@ -23,7 +23,7 @@ def create_backup(machine: Machine, provider: str | None = None) -> None:
|
|||||||
if provider not in backup_scripts["providers"]:
|
if provider not in backup_scripts["providers"]:
|
||||||
msg = f"provider {provider} not found"
|
msg = f"provider {provider} not found"
|
||||||
raise ClanError(msg)
|
raise ClanError(msg)
|
||||||
with host.ssh_control_master() as ssh:
|
with host.host_connection() as ssh:
|
||||||
proc = ssh.run(
|
proc = ssh.run(
|
||||||
[backup_scripts["providers"][provider]["create"]],
|
[backup_scripts["providers"][provider]["create"]],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def list_provider(machine: Machine, host: Remote, provider: str) -> list[Backup]
|
|||||||
def list_backups(machine: Machine, provider: str | None = None) -> list[Backup]:
|
def list_backups(machine: Machine, provider: str | None = None) -> list[Backup]:
|
||||||
backup_metadata = machine.select("config.clan.core.backups")
|
backup_metadata = machine.select("config.clan.core.backups")
|
||||||
results = []
|
results = []
|
||||||
with machine.target_host().ssh_control_master() as host:
|
with machine.target_host().host_connection() as host:
|
||||||
if provider is None:
|
if provider is None:
|
||||||
for _provider in backup_metadata["providers"]:
|
for _provider in backup_metadata["providers"]:
|
||||||
results += list_provider(machine, host, _provider)
|
results += list_provider(machine, host, _provider)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ def restore_service(
|
|||||||
# FIXME: If we have too many folder this might overflow the stack.
|
# FIXME: If we have too many folder this might overflow the stack.
|
||||||
env["FOLDERS"] = ":".join(set(folders))
|
env["FOLDERS"] = ":".join(set(folders))
|
||||||
|
|
||||||
with host.ssh_control_master() as ssh:
|
with host.host_connection() as ssh:
|
||||||
if pre_restore := backup_folders[service]["preRestoreCommand"]:
|
if pre_restore := backup_folders[service]["preRestoreCommand"]:
|
||||||
proc = ssh.run(
|
proc = ssh.run(
|
||||||
[pre_restore],
|
[pre_restore],
|
||||||
@@ -58,7 +58,7 @@ def restore_backup(
|
|||||||
service: str | None = None,
|
service: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
errors = []
|
errors = []
|
||||||
with machine.target_host().ssh_control_master() as host:
|
with machine.target_host().host_connection() as host:
|
||||||
if service is None:
|
if service is None:
|
||||||
backup_folders = machine.select("config.clan.core.state")
|
backup_folders = machine.select("config.clan.core.state")
|
||||||
for _service in backup_folders:
|
for _service in backup_folders:
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ def run_machine_hardware_info(
|
|||||||
"--show-hardware-config",
|
"--show-hardware-config",
|
||||||
]
|
]
|
||||||
|
|
||||||
with target_host.ssh_control_master() as ssh, ssh.become_root() as sudo_ssh:
|
with target_host.host_connection() as ssh, ssh.become_root() as sudo_ssh:
|
||||||
out = sudo_ssh.run(config_command, opts=RunOpts(check=False))
|
out = sudo_ssh.run(config_command, opts=RunOpts(check=False))
|
||||||
if out.returncode != 0:
|
if out.returncode != 0:
|
||||||
if "nixos-facter" in out.stderr and "not found" in out.stderr:
|
if "nixos-facter" in out.stderr and "not found" in out.stderr:
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from clan_lib.colors import AnsiColor
|
|||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
from clan_lib.machines.machines import Machine
|
from clan_lib.machines.machines import Machine
|
||||||
from clan_lib.nix import nix_command, nix_metadata
|
from clan_lib.nix import nix_command, nix_metadata
|
||||||
|
from clan_lib.ssh.host import Host
|
||||||
from clan_lib.ssh.remote import Remote
|
from clan_lib.ssh.remote import Remote
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -37,7 +38,7 @@ def is_local_input(node: dict[str, dict[str, str]]) -> bool:
|
|||||||
return local
|
return local
|
||||||
|
|
||||||
|
|
||||||
def upload_sources(machine: Machine, ssh: Remote, force_fetch_local: bool) -> str:
|
def upload_sources(machine: Machine, ssh: Host, force_fetch_local: bool) -> str:
|
||||||
env = ssh.nix_ssh_env(os.environ.copy())
|
env = ssh.nix_ssh_env(os.environ.copy())
|
||||||
|
|
||||||
flake_url = (
|
flake_url = (
|
||||||
@@ -110,8 +111,8 @@ def upload_sources(machine: Machine, ssh: Remote, force_fetch_local: bool) -> st
|
|||||||
@API.register
|
@API.register
|
||||||
def run_machine_update(
|
def run_machine_update(
|
||||||
machine: Machine,
|
machine: Machine,
|
||||||
target_host: Remote,
|
target_host: Host,
|
||||||
build_host: Remote | None,
|
build_host: Host | None,
|
||||||
force_fetch_local: bool = False,
|
force_fetch_local: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update an existing machine using nixos-rebuild or darwin-rebuild.
|
"""Update an existing machine using nixos-rebuild or darwin-rebuild.
|
||||||
@@ -126,13 +127,13 @@ def run_machine_update(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
with ExitStack() as stack:
|
with ExitStack() as stack:
|
||||||
target_host = stack.enter_context(target_host.ssh_control_master())
|
target_host = stack.enter_context(target_host.host_connection())
|
||||||
|
|
||||||
# If no build host is specified, use the target host as the build host.
|
# If no build host is specified, use the target host as the build host.
|
||||||
if build_host is None:
|
if build_host is None:
|
||||||
build_host = target_host
|
build_host = target_host
|
||||||
else:
|
else:
|
||||||
build_host = stack.enter_context(build_host.ssh_control_master())
|
stack.enter_context(build_host.host_connection())
|
||||||
|
|
||||||
# Some operations require root privileges on the target host.
|
# Some operations require root privileges on the target host.
|
||||||
target_host_root = stack.enter_context(target_host.become_root())
|
target_host_root = stack.enter_context(target_host.become_root())
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def check_machine_ssh_login(
|
|||||||
if opts is None:
|
if opts is None:
|
||||||
opts = ConnectionOptions()
|
opts = ConnectionOptions()
|
||||||
|
|
||||||
with remote.ssh_control_master() as ssh:
|
with remote.host_connection() as ssh:
|
||||||
try:
|
try:
|
||||||
ssh.run(
|
ssh.run(
|
||||||
["true"],
|
["true"],
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from clan_lib.cmd import CmdOut, RunOpts, run
|
|||||||
from clan_lib.colors import AnsiColor
|
from clan_lib.colors import AnsiColor
|
||||||
from clan_lib.errors import ClanError, indent_command # Assuming these are available
|
from clan_lib.errors import ClanError, indent_command # Assuming these are available
|
||||||
from clan_lib.nix import nix_shell
|
from clan_lib.nix import nix_shell
|
||||||
|
from clan_lib.ssh.host import Host
|
||||||
from clan_lib.ssh.host_key import HostKeyCheck, hostkey_to_ssh_opts
|
from clan_lib.ssh.host_key import HostKeyCheck, hostkey_to_ssh_opts
|
||||||
from clan_lib.ssh.parse import parse_ssh_uri
|
from clan_lib.ssh.parse import parse_ssh_uri
|
||||||
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
|
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
|
||||||
@@ -30,7 +31,7 @@ NO_OUTPUT_TIMEOUT = 20
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Remote:
|
class Remote(Host):
|
||||||
address: str
|
address: str
|
||||||
command_prefix: str
|
command_prefix: str
|
||||||
user: str = "root"
|
user: str = "root"
|
||||||
@@ -136,7 +137,7 @@ class Remote:
|
|||||||
return run(cmd, opts)
|
return run(cmd, opts)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def ssh_control_master(self) -> Iterator["Remote"]:
|
def host_connection(self) -> Iterator["Remote"]:
|
||||||
"""
|
"""
|
||||||
Context manager to manage SSH ControlMaster connections.
|
Context manager to manage SSH ControlMaster connections.
|
||||||
This will create a temporary directory for the control socket.
|
This will create a temporary directory for the control socket.
|
||||||
@@ -318,11 +319,11 @@ class Remote:
|
|||||||
if env is None:
|
if env is None:
|
||||||
env = {}
|
env = {}
|
||||||
env["NIX_SSHOPTS"] = " ".join(
|
env["NIX_SSHOPTS"] = " ".join(
|
||||||
self.ssh_cmd_opts(control_master=control_master) # Renamed
|
self._ssh_cmd_opts(control_master=control_master) # Renamed
|
||||||
)
|
)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def ssh_cmd_opts(
|
def _ssh_cmd_opts(
|
||||||
self,
|
self,
|
||||||
control_master: bool = True,
|
control_master: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
@@ -373,7 +374,7 @@ class Remote:
|
|||||||
packages.append("sshpass")
|
packages.append("sshpass")
|
||||||
password_args = ["sshpass", "-p", self.password]
|
password_args = ["sshpass", "-p", self.password]
|
||||||
|
|
||||||
current_ssh_opts = self.ssh_cmd_opts(control_master=control_master)
|
current_ssh_opts = self._ssh_cmd_opts(control_master=control_master)
|
||||||
if verbose_ssh or self.verbose_ssh:
|
if verbose_ssh or self.verbose_ssh:
|
||||||
current_ssh_opts.extend(["-v"])
|
current_ssh_opts.extend(["-v"])
|
||||||
if tty:
|
if tty:
|
||||||
@@ -396,7 +397,7 @@ class Remote:
|
|||||||
]
|
]
|
||||||
return nix_shell(packages, cmd)
|
return nix_shell(packages, cmd)
|
||||||
|
|
||||||
def check_sshpass_errorcode(self, res: subprocess.CompletedProcess) -> None:
|
def _check_sshpass_errorcode(self, res: subprocess.CompletedProcess) -> None:
|
||||||
"""
|
"""
|
||||||
Check the return code of the sshpass command and raise an error if it indicates a failure.
|
Check the return code of the sshpass command and raise an error if it indicates a failure.
|
||||||
Error codes are based on man sshpass(1) and may vary by version.
|
Error codes are based on man sshpass(1) and may vary by version.
|
||||||
@@ -454,7 +455,7 @@ class Remote:
|
|||||||
# We only check the error code if a password is set, as sshpass is used.
|
# We only check the error code if a password is set, as sshpass is used.
|
||||||
# AS sshpass swallows all output.
|
# AS sshpass swallows all output.
|
||||||
if self.password:
|
if self.password:
|
||||||
self.check_sshpass_errorcode(res)
|
self._check_sshpass_errorcode(res)
|
||||||
|
|
||||||
def check_machine_ssh_reachable(
|
def check_machine_ssh_reachable(
|
||||||
self, opts: "ConnectionOptions | None" = None
|
self, opts: "ConnectionOptions | None" = None
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ def test_run_no_shell(hosts: list[Remote], runtime: AsyncRuntime) -> None:
|
|||||||
|
|
||||||
def test_sudo_ask_proxy(hosts: list[Remote]) -> None:
|
def test_sudo_ask_proxy(hosts: list[Remote]) -> None:
|
||||||
host = hosts[0]
|
host = hosts[0]
|
||||||
with host.ssh_control_master() as host:
|
with host.host_connection() as host:
|
||||||
proxy = SudoAskpassProxy(host, prompt_command=["bash", "-c", "echo yes"])
|
proxy = SudoAskpassProxy(host, prompt_command=["bash", "-c", "echo yes"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -197,7 +197,7 @@ def test_sudo_ask_proxy(hosts: list[Remote]) -> None:
|
|||||||
|
|
||||||
def test_run_function(hosts: list[Remote], runtime: AsyncRuntime) -> None:
|
def test_run_function(hosts: list[Remote], runtime: AsyncRuntime) -> None:
|
||||||
def some_func(h: Remote) -> bool:
|
def some_func(h: Remote) -> bool:
|
||||||
with h.ssh_control_master() as ssh:
|
with h.host_connection() as ssh:
|
||||||
p = ssh.run(["echo", "hello"])
|
p = ssh.run(["echo", "hello"])
|
||||||
return p.stdout == "hello\n"
|
return p.stdout == "hello\n"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user