diff --git a/pkgs/clan-cli/clan_lib/network/network.py b/pkgs/clan-cli/clan_lib/network/network.py index 4d7294ceb..be7d434e7 100644 --- a/pkgs/clan-cli/clan_lib/network/network.py +++ b/pkgs/clan-cli/clan_lib/network/network.py @@ -136,92 +136,123 @@ def networks_from_flake(flake: Flake) -> dict[str, Network]: return networks -@contextmanager -def get_best_remote(machine: "Machine") -> Iterator["Remote"]: - """Context manager that yields the best remote connection for a machine following this priority: - 1. If machine has targetHost in inventory, return a direct connection - 2. Return the highest priority network where machine is reachable - 3. If no network works, try to get targetHost from machine nixos config +class BestRemoteContext: + """Class-based context manager for establishing and maintaining network connections.""" - Args: - machine: Machine instance to connect to + def __init__(self, machine: "Machine") -> None: + self.machine = machine + self._network_ctx: Any = None + self._remote: Remote | None = None - Yields: - Remote object for connecting to the machine + def __enter__(self) -> "Remote": + """Establish the best remote connection for a machine following this priority: + 1. If machine has targetHost in inventory, return a direct connection + 2. Return the highest priority network where machine is reachable + 3. If no network works, try to get targetHost from machine nixos config - Raises: - ClanError: If no connection method works + Returns: + Remote object for connecting to the machine - """ - # Step 1: Check if targetHost is set in inventory - inv_machine = machine.get_inv_machine() - target_host = inv_machine.get("deploy", {}).get("targetHost") + Raises: + ClanError: If no connection method works - if target_host: - log.debug(f"Using targetHost from inventory for {machine.name}: {target_host}") - # Create a direct network with just this machine - remote = Remote.from_ssh_uri(machine_name=machine.name, address=target_host) - yield remote - return + """ + # Step 1: Check if targetHost is set in inventory + inv_machine = self.machine.get_inv_machine() + target_host = inv_machine.get("deploy", {}).get("targetHost") - # Step 2: Try existing networks by priority - try: - networks = networks_from_flake(machine.flake) + if target_host: + log.debug( + f"Using targetHost from inventory for {self.machine.name}: {target_host}" + ) + self._remote = Remote.from_ssh_uri( + machine_name=self.machine.name, address=target_host + ) + return self._remote - sorted_networks = sorted(networks.items(), key=lambda x: -x[1].priority) + # Step 2: Try existing networks by priority + try: + networks = networks_from_flake(self.machine.flake) + sorted_networks = sorted(networks.items(), key=lambda x: -x[1].priority) - for network_name, network in sorted_networks: - if machine.name not in network.peers: - continue + for network_name, network in sorted_networks: + if self.machine.name not in network.peers: + continue - # Check if network is running and machine is reachable - log.debug(f"trying to connect via {network_name}") - if network.is_running(): - try: - ping_time = network.ping(machine.name) - if ping_time is not None: - log.info( - f"Machine {machine.name} reachable via {network_name} network", - ) - yield network.remote(machine.name) - return - except ClanError as e: - log.debug(f"Failed to reach {machine.name} via {network_name}: {e}") - else: - try: - log.debug(f"Establishing connection for network {network_name}") - with network.module.connection(network) as connected_network: - ping_time = connected_network.ping(machine.name) + log.debug(f"trying to connect via {network_name}") + if network.is_running(): + try: + ping_time = network.ping(self.machine.name) if ping_time is not None: log.info( - f"Machine {machine.name} reachable via {network_name} network after connection", + f"Machine {self.machine.name} reachable via {network_name} network", ) - yield connected_network.remote(machine.name) - return - except ClanError as e: - log.debug( - f"Failed to establish connection to {machine.name} via {network_name}: {e}", - ) - except (ImportError, AttributeError, KeyError) as e: - log.debug(f"Failed to use networking modules to determine machines remote: {e}") + self._remote = remote = network.remote(self.machine.name) + return remote + except ClanError as e: + log.debug( + f"Failed to reach {self.machine.name} via {network_name}: {e}" + ) + else: + try: + log.debug(f"Establishing connection for network {network_name}") + # Enter the network context and keep it alive + self._network_ctx = network.module.connection(network) + connected_network = self._network_ctx.__enter__() + ping_time = connected_network.ping(self.machine.name) + if ping_time is not None: + log.info( + f"Machine {self.machine.name} reachable via {network_name} network after connection", + ) + self._remote = remote = connected_network.remote( + self.machine.name + ) + return remote + # Ping failed, clean up this connection attempt + self._network_ctx.__exit__(None, None, None) + self._network_ctx = None + except ClanError as e: + # Clean up failed connection attempt + if self._network_ctx is not None: + self._network_ctx.__exit__(None, None, None) + self._network_ctx = None + log.debug( + f"Failed to establish connection to {self.machine.name} via {network_name}: {e}", + ) + except (ImportError, AttributeError, KeyError) as e: + log.debug( + f"Failed to use networking modules to determine machines remote: {e}" + ) - # Step 3: Try targetHost from machine nixos config - target_host = machine.select('config.clan.core.networking."targetHost"') - if target_host: - log.debug( - f"Using targetHost from machine config for {machine.name}: {target_host}", - ) - # Check if reachable - remote = Remote.from_ssh_uri( - machine_name=machine.name, - address=target_host, - ) - yield remote - return + # Step 3: Try targetHost from machine nixos config + target_host = self.machine.select('config.clan.core.networking."targetHost"') + if target_host: + log.debug( + f"Using targetHost from machine config for {self.machine.name}: {target_host}", + ) + self._remote = Remote.from_ssh_uri( + machine_name=self.machine.name, + address=target_host, + ) + return self._remote - # No connection method found - msg = f"Could not find any way to connect to machine '{machine.name}'. No targetHost configured and machine not reachable via any network." - raise ClanError(msg) + # No connection method found + msg = f"Could not find any way to connect to machine '{self.machine.name}'. No targetHost configured and machine not reachable via any network." + raise ClanError(msg) + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object, + ) -> None: + """Clean up network connection if one was established.""" + if self._network_ctx is not None: + self._network_ctx.__exit__(exc_type, exc_val, exc_tb) + + +def get_best_remote(machine: "Machine") -> BestRemoteContext: + return BestRemoteContext(machine) def get_network_overview(networks: dict[str, Network]) -> dict: