enable bug-bear linting rules

This commit is contained in:
Jörg Thalheim
2024-09-02 13:26:07 +02:00
parent b313f2d066
commit 109d1faf9e
33 changed files with 214 additions and 104 deletions

View File

@@ -54,9 +54,9 @@ class CommandFormatter(logging.Formatter):
prefix_color = ansi_color(self.hostname_colorcode(command_prefix))
color_reset = "\x1b[0m"
setattr(record, "color", color)
setattr(record, "prefix_color", prefix_color)
setattr(record, "color_reset", color_reset)
record.color = color
record.prefix_color = prefix_color
record.color_reset = color_reset
return super().format(record)
@@ -144,9 +144,9 @@ class Host:
forward_agent: bool = False,
command_prefix: str | None = None,
host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
meta: dict[str, Any] = {},
meta: dict[str, Any] | None = None,
verbose_ssh: bool = False,
ssh_options: dict[str, str] = {},
ssh_options: dict[str, str] | None = None,
) -> None:
"""
Creates a Host
@@ -158,6 +158,10 @@ class Host:
@verbose_ssh: Enables verbose logging on ssh connections
@meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`
"""
if ssh_options is None:
ssh_options = {}
if meta is None:
meta = {}
self.host = host
self.user = user
self.port = port
@@ -200,7 +204,9 @@ class Host:
start = time.time()
last_output = time.time()
while len(rlist) != 0:
r, _, _ = select.select(rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT))
readlist, _, _ = select.select(
rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)
)
def print_from(
print_fd: IO[str], print_buf: str, is_err: bool = False
@@ -227,11 +233,11 @@ class Host:
last_output = time.time()
return (last_output, print_buf)
if print_std_fd in r and print_std_fd is not None:
if print_std_fd in readlist and print_std_fd is not None:
(last_output, print_std_buf) = print_from(
print_std_fd, print_std_buf, is_err=False
)
if print_err_fd in r and print_err_fd is not None:
if print_err_fd in readlist and print_err_fd is not None:
(last_output, print_err_buf) = print_from(
print_err_fd, print_err_buf, is_err=True
)
@@ -245,8 +251,8 @@ class Host:
extra=dict(command_prefix=self.command_prefix),
)
def handle_fd(fd: IO[Any] | None) -> str:
if fd and fd in r:
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> str:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(fd)
@@ -254,8 +260,8 @@ class Host:
return read.decode("utf-8")
return ""
stdout_buf += handle_fd(stdout)
stderr_buf += handle_fd(stderr)
stdout_buf += handle_fd(stdout, readlist)
stderr_buf += handle_fd(stderr, readlist)
if now - last_output >= timeout:
break
@@ -268,11 +274,13 @@ class Host:
shell: bool,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
) -> subprocess.CompletedProcess[str]:
if extra_env is None:
extra_env = {}
with ExitStack() as stack:
read_std_fd, write_std_fd = (None, None)
read_err_fd, write_err_fd = (None, None)
@@ -354,7 +362,7 @@ class Host:
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
@@ -371,6 +379,8 @@ class Host:
@return subprocess.CompletedProcess result of the command
"""
if extra_env is None:
extra_env = {}
shell = False
if isinstance(cmd, str):
cmd = [cmd]
@@ -397,7 +407,7 @@ class Host:
stdout: FILE = None,
stderr: FILE = None,
become_root: bool = False,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
@@ -418,6 +428,8 @@ class Host:
@return subprocess.CompletedProcess result of the ssh command
"""
if extra_env is None:
extra_env = {}
sudo = ""
if become_root and self.user != "root":
sudo = "sudo -- "
@@ -548,13 +560,15 @@ class HostGroup:
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if extra_env is None:
extra_env = {}
try:
proc = host.run_local(
cmd,
@@ -577,13 +591,15 @@ class HostGroup:
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if extra_env is None:
extra_env = {}
try:
proc = host.run(
cmd,
@@ -622,13 +638,15 @@ class HostGroup:
local: bool = False,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
) -> Results:
if extra_env is None:
extra_env = {}
results: Results = []
threads = []
for host in self.hosts:
@@ -665,7 +683,7 @@ class HostGroup:
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
@@ -682,6 +700,8 @@ class HostGroup:
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
stdout=stdout,
@@ -699,7 +719,7 @@ class HostGroup:
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
@@ -715,6 +735,8 @@ class HostGroup:
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
local=True,
@@ -761,8 +783,13 @@ class HostGroup:
def parse_deployment_address(
machine_name: str, host: str, forward_agent: bool = True, meta: dict[str, Any] = {}
machine_name: str,
host: str,
forward_agent: bool = True,
meta: dict[str, Any] | None = None,
) -> Host:
if meta is None:
meta = {}
parts = host.split("@")
user: str | None = None
if len(parts) > 1:

View File

@@ -15,9 +15,11 @@ def ssh(
host: str,
user: str = "root",
password: str | None = None,
ssh_args: list[str] = [],
ssh_args: list[str] | None = None,
torify: bool = False,
) -> None:
if ssh_args is None:
ssh_args = []
packages = ["nixpkgs#openssh"]
if torify:
packages.append("nixpkgs#tor")