diff --git a/pkgs/clan-cli/clan_cli/ssh/upload.py b/pkgs/clan-cli/clan_cli/ssh/upload.py index f7ed5e74e..ec22df73f 100644 --- a/pkgs/clan-cli/clan_cli/ssh/upload.py +++ b/pkgs/clan-cli/clan_cli/ssh/upload.py @@ -1,21 +1,28 @@ import tarfile from pathlib import Path +from shlex import quote from tempfile import TemporaryDirectory from clan_cli.cmd import Log, RunOpts from clan_cli.cmd import run as run_local +from clan_cli.errors import ClanError from clan_cli.ssh.host import Host def upload( host: Host, - local_src: Path, # must be a directory + local_src: Path, remote_dest: Path, # must be a directory file_user: str = "root", file_group: str = "root", dir_mode: int = 0o700, file_mode: int = 0o400, ) -> None: + # Check if the remote destination is at least 3 directories deep + if len(remote_dest.parts) < 3: + msg = f"The remote destination must be at least 3 directories deep. Got: {remote_dest}. Reason: The directory will be deleted with 'rm -rf'." + raise ClanError(msg) + # Create the tarball from the temporary directory with TemporaryDirectory(prefix="facts-upload-") as tardir: tar_path = Path(tardir) / "upload.tar.gz" @@ -55,64 +62,22 @@ def upload( with local_src.open("rb") as f: tar.addfile(tarinfo, f) - priviledge_escalation = [] - if host.user != "root": - priviledge_escalation = ["sudo", "--"] + sudo = "" + if host.user != "root" and os.environ.get("IN_PYTEST") is None: + sudo = "sudo -- " - if local_src.is_dir(): - cmd = [ - *host.ssh_cmd(), - "--", - *priviledge_escalation, - "bash", - "-c", - 'exec "$@"', - "--", - "rm", - "-r", - str(remote_dest), - "mkdir", - "-m", - f"{dir_mode:o}", - "-p", - str(remote_dest), - "&&", - "tar", - "-C", - str(remote_dest), - "-xzf", - "-", - ] - else: - # For single file, extract to parent directory and ensure correct name - cmd = [ - *host.ssh_cmd(), - "--", - *priviledge_escalation, - "bash", - "-c", - 'exec "$@"', - "--", - "rm", - "-r", - str(remote_dest), - "mkdir", - "-m", - f"{dir_mode:o}", - "-p", - str(remote_dest.parent), - "&&", - "tar", - "-C", - str(remote_dest.parent), - "-xzf", - "-", - ] + cmd = "rm -rf $0 && mkdir -m $1 -p $0 && tar -C $0 -xzf -" # TODO accept `input` to be an IO object instead of bytes so that we don't have to read the tarfile into memory. with tar_path.open("rb") as f: run_local( - cmd, + [ + *host.ssh_cmd(), + "--", + f"{sudo}bash -c {quote(cmd)}", + str(remote_dest), + f"{dir_mode:o}", + ], RunOpts( input=f.read(), log=Log.BOTH, diff --git a/pkgs/clan-cli/default.nix b/pkgs/clan-cli/default.nix index a3214cb2d..101f872aa 100644 --- a/pkgs/clan-cli/default.nix +++ b/pkgs/clan-cli/default.nix @@ -147,6 +147,7 @@ pythonRuntime.pkgs.buildPythonApplication { cd ./src export NIX_STATE_DIR=$TMPDIR/nix IN_NIX_SANDBOX=1 PYTHONWARNINGS=error + export IN_PYTEST=1 # required to prevent concurrent 'nix flake lock' operations export CLAN_TEST_STORE=$TMPDIR/store @@ -198,6 +199,7 @@ pythonRuntime.pkgs.buildPythonApplication { export NIX_STATE_DIR=$TMPDIR/nix export IN_NIX_SANDBOX=1 export PYTHONWARNINGS=error + export IN_PYTEST=1 export CLAN_TEST_STORE=$TMPDIR/store # required to prevent concurrent 'nix flake lock' operations export LOCK_NIX=$TMPDIR/nix_lock diff --git a/pkgs/clan-cli/tests/test_secrets_upload.py b/pkgs/clan-cli/tests/test_secrets_upload.py index 2556fb863..8b3ed7930 100644 --- a/pkgs/clan-cli/tests/test_secrets_upload.py +++ b/pkgs/clan-cli/tests/test_secrets_upload.py @@ -26,6 +26,17 @@ def test_secrets_upload( monkeypatch.chdir(str(flake.path)) monkeypatch.setenv("SOPS_AGE_KEY", age_keys[0].privkey) + sops_dir = flake.path / "facts" + + # the flake defines this path as the location where the sops key should be installed + sops_key = sops_dir / "key.txt" + sops_key2 = sops_dir / "key2.txt" + + # Create old state, which should be cleaned up + sops_dir.mkdir() + sops_key.write_text("OLD STATE") + sops_key2.write_text("OLD STATE2") + cli.run( [ "secrets", @@ -56,8 +67,6 @@ def test_secrets_upload( cli.run(["facts", "upload", "--flake", str(flake_path), "vm1"]) - # the flake defines this path as the location where the sops key should be installed - sops_key = flake.path / "facts" / "key.txt" - assert sops_key.exists() assert sops_key.read_text() == age_keys[0].privkey + assert not sops_key2.exists()