enable bug-bear linting rules

This commit is contained in:
Jörg Thalheim
2024-09-02 13:26:07 +02:00
parent af4b9cc2d5
commit 35839ef701
33 changed files with 214 additions and 104 deletions

View File

@@ -85,7 +85,7 @@ def register_common_flags(parser: argparse.ArgumentParser) -> None:
has_subparsers = False
for action in parser._actions:
if isinstance(action, argparse._SubParsersAction):
for choice, child_parser in action.choices.items():
for _choice, child_parser in action.choices.items():
has_subparsers = True
register_common_flags(child_parser)
if not has_subparsers:

View File

@@ -27,12 +27,14 @@ def set_admin_service(
base_url: str,
allowed_keys: dict[str, str],
instance_name: str = "admin",
extra_machines: list[str] = [],
extra_machines: list[str] | None = None,
) -> None:
"""
Set the admin service of a clan
Every machine is by default part of the admin service via the 'all' tag
"""
if extra_machines is None:
extra_machines = []
inventory = load_inventory_eval(base_url)
if not allowed_keys:

View File

@@ -55,7 +55,7 @@ def extract_frontmatter(readme_content: str, err_scope: str) -> tuple[Frontmatte
f"Error parsing TOML frontmatter: {e}",
description=f"Invalid TOML frontmatter. {err_scope}",
location="extract_frontmatter",
)
) from e
return Frontmatter(**frontmatter_parsed), remaining_content
@@ -97,12 +97,12 @@ def get_modules(base_path: str) -> dict[str, str]:
try:
proc = run_no_stdout(cmd)
res = proc.stdout.strip()
except ClanCmdError:
except ClanCmdError as e:
raise ClanError(
"clanInternals might not have clanModules attributes",
location=f"list_modules {base_path}",
description="Evaluation failed on clanInternals.clanModules attribute",
)
) from e
modules: dict[str, str] = json.loads(res)
return modules

View File

@@ -121,11 +121,13 @@ JsonValue = str | float | dict[str, Any] | list[Any] | None
def construct_value(
t: type | UnionType, field_value: JsonValue, loc: list[str] = []
t: type | UnionType, field_value: JsonValue, loc: list[str] | None = None
) -> Any:
"""
Construct a field value from a type hint and a field value.
"""
if loc is None:
loc = []
if t is None and field_value:
raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}")
@@ -203,11 +205,15 @@ def construct_value(
raise ClanError(f"Unhandled field type {t} with value {field_value}")
def construct_dataclass(t: type[T], data: dict[str, Any], path: list[str] = []) -> T:
def construct_dataclass(
t: type[T], data: dict[str, Any], path: list[str] | None = None
) -> T:
"""
type t MUST be a dataclass
Dynamically instantiate a data class from a dictionary, handling nested data classes.
"""
if path is None:
path = []
if not is_dataclass(t):
raise ClanError(f"{t.__name__} is not a dataclass")
@@ -253,8 +259,10 @@ def construct_dataclass(t: type[T], data: dict[str, Any], path: list[str] = [])
def from_dict(
t: type | UnionType, data: dict[str, Any] | Any, path: list[str] = []
t: type | UnionType, data: dict[str, Any] | Any, path: list[str] | None = None
) -> Any:
if path is None:
path = []
if is_dataclass(t):
if not isinstance(data, dict):
raise ClanError(f"{data} is not a dict. Expected {t}")

View File

@@ -30,7 +30,7 @@ def inspect_dataclass_fields(t: type) -> dict[TypeVar, type]:
type_params = origin.__parameters__
# Create a map from type parameters to actual type arguments
type_map = dict(zip(type_params, type_args))
type_map = dict(zip(type_params, type_args, strict=False))
return type_map
@@ -67,7 +67,11 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
return schema
def type_to_dict(t: Any, scope: str = "", type_map: dict[TypeVar, type] = {}) -> dict:
def type_to_dict(
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
) -> dict:
if type_map is None:
type_map = {}
if t is None:
return {"type": "null"}

View File

@@ -31,7 +31,7 @@ def show_clan_meta(uri: str | Path) -> Meta:
"Evaluation failed on meta attribute",
location=f"show_clan {uri}",
description=str(e.cmd),
)
) from e
clan_meta = json.loads(res)

View File

@@ -29,28 +29,28 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
stderr_buf = b""
while len(rlist) != 0:
r, _, _ = select.select(rlist, [], [], 0.1)
if len(r) == 0: # timeout in select
readlist, _, _ = select.select(rlist, [], [], 0.1)
if len(readlist) == 0: # timeout in select
if process.poll() is None:
continue
# Process has exited
break
def handle_fd(fd: IO[Any] | None) -> bytes:
if fd and fd in r:
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> bytes:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
if len(read) != 0:
return read
rlist.remove(fd)
return b""
ret = handle_fd(process.stdout)
ret = handle_fd(process.stdout, readlist)
if ret and log in [Log.STDOUT, Log.BOTH]:
sys.stdout.buffer.write(ret)
sys.stdout.flush()
stdout_buf += ret
ret = handle_fd(process.stderr)
ret = handle_fd(process.stderr, readlist)
if ret and log in [Log.STDERR, Log.BOTH]:
sys.stderr.buffer.write(ret)
@@ -103,11 +103,13 @@ def run(
*,
input: bytes | None = None, # noqa: A002
env: dict[str, str] | None = None,
cwd: Path = Path.cwd(),
cwd: Path | None = None,
log: Log = Log.STDERR,
check: bool = True,
error_msg: str | None = None,
) -> CmdOut:
if cwd is None:
cwd = Path.cwd()
if input:
glog.debug(
f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
@@ -155,7 +157,7 @@ def run_no_stdout(
cmd: list[str],
*,
env: dict[str, str] | None = None,
cwd: Path = Path.cwd(),
cwd: Path | None = None,
log: Log = Log.STDERR,
check: bool = True,
error_msg: str | None = None,
@@ -164,6 +166,8 @@ def run_no_stdout(
Like run, but automatically suppresses stdout, if not in DEBUG log level.
If in DEBUG log level the stdout of commands will be shown.
"""
if cwd is None:
cwd = Path.cwd()
if logging.getLogger(__name__.split(".")[0]).isEnabledFor(logging.DEBUG):
return run(cmd, env=env, log=log, check=check, error_msg=error_msg)
else:

View File

@@ -44,7 +44,9 @@ def map_type(nix_type: str) -> Any:
# merge two dicts recursively
def merge(a: dict, b: dict, path: list[str] = []) -> dict:
def merge(a: dict, b: dict, path: list[str] | None = None) -> dict:
if path is None:
path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
@@ -98,10 +100,10 @@ def cast(value: Any, input_type: Any, opt_description: str) -> Any:
if len(value) > 1:
raise ClanError(f"Too many values for {opt_description}")
return input_type(value[0])
except ValueError:
except ValueError as e:
raise ClanError(
f"Invalid type for option {opt_description} (expected {input_type.__name__})"
)
) from e
def options_for_machine(

View File

@@ -18,8 +18,10 @@ def machine_schema(
flake_dir: Path,
config: dict[str, Any],
clan_imports: list[str] | None = None,
option_path: list[str] = ["clan"],
option_path: list[str] | None = None,
) -> dict[str, Any]:
if option_path is None:
option_path = ["clan"]
# use nix eval to lib.evalModules .#nixosConfigurations.<machine_name>.options.clan
with NamedTemporaryFile(mode="w", dir=flake_dir) as clan_machine_settings_file:
env = os.environ.copy()

View File

@@ -48,7 +48,7 @@ def list_possible_keymaps() -> list[str]:
keymap_files = []
for root, _, files in os.walk(keymaps_dir):
for _root, _, files in os.walk(keymaps_dir):
for file in files:
if file.endswith(".map.gz"):
# Remove '.map.gz' ending
@@ -93,8 +93,10 @@ def flash_machine(
dry_run: bool,
write_efi_boot_entries: bool,
debug: bool,
extra_args: list[str] = [],
extra_args: list[str] | None = None,
) -> None:
if extra_args is None:
extra_args = []
system_config_nix: dict[str, Any] = {}
if system_config.wifi_settings:
@@ -125,7 +127,9 @@ def flash_machine(
try:
root_keys.append(key_path.read_text())
except OSError as e:
raise ClanError(f"Cannot read SSH public key file: {key_path}: {e}")
raise ClanError(
f"Cannot read SSH public key file: {key_path}: {e}"
) from e
system_config_nix["users"] = {
"users": {"root": {"openssh": {"authorizedKeys": {"keys": root_keys}}}}
}

View File

@@ -114,7 +114,7 @@ def load_inventory_eval(flake_dir: str | Path) -> Inventory:
inventory = from_dict(Inventory, data)
return inventory
except json.JSONDecodeError as e:
raise ClanError(f"Error decoding inventory from flake: {e}")
raise ClanError(f"Error decoding inventory from flake: {e}") from e
def load_inventory_json(
@@ -134,7 +134,7 @@ def load_inventory_json(
inventory = from_dict(Inventory, res)
except json.JSONDecodeError as e:
# Error decoding the inventory file
raise ClanError(f"Error decoding inventory file: {e}")
raise ClanError(f"Error decoding inventory file: {e}") from e
if not inventory_file.exists():
# Copy over the meta from the flake if the inventory is not initialized

View File

@@ -209,7 +209,7 @@ def generate_machine_hardware_info(
"Invalid hardware-configuration.nix file",
description="The hardware-configuration.nix file is invalid. Please check the file and try again.",
location=f"{__name__} {hw_file}",
)
) from e
return HardwareReport(report_type)

View File

@@ -29,8 +29,10 @@ def install_nixos(
debug: bool = False,
password: str | None = None,
no_reboot: bool = False,
extra_args: list[str] = [],
extra_args: list[str] | None = None,
) -> None:
if extra_args is None:
extra_args = []
secret_facts_module = importlib.import_module(machine.secret_facts_module)
log.info(f"installing {machine.name}")
secret_facts_store = secret_facts_module.SecretStore(machine=machine)

View File

@@ -73,7 +73,7 @@ def list_nixos_machines(flake_url: str | Path) -> list[str]:
data = json.loads(res)
return data
except json.JSONDecodeError as e:
raise ClanError(f"Error decoding machines from flake: {e}")
raise ClanError(f"Error decoding machines from flake: {e}") from e
@dataclass

View File

@@ -133,12 +133,14 @@ class Machine:
attr: str,
extra_config: None | dict = None,
impure: bool = False,
nix_options: list[str] = [],
nix_options: list[str] | None = None,
) -> str | Path:
"""
Build the machine and return the path to the result
accepts a secret store and a facts store # TODO
"""
if nix_options is None:
nix_options = []
config = nix_config()
system = config["system"]
@@ -216,12 +218,14 @@ class Machine:
refresh: bool = False,
extra_config: None | dict = None,
impure: bool = False,
nix_options: list[str] = [],
nix_options: list[str] | None = None,
) -> str:
"""
eval a nix attribute of the machine
@attr: the attribute to get
"""
if nix_options is None:
nix_options = []
if attr in self._eval_cache and not refresh and extra_config is None:
return self._eval_cache[attr]
@@ -238,13 +242,15 @@ class Machine:
refresh: bool = False,
extra_config: None | dict = None,
impure: bool = False,
nix_options: list[str] = [],
nix_options: list[str] | None = None,
) -> Path:
"""
build a nix attribute of the machine
@attr: the attribute to get
"""
if nix_options is None:
nix_options = []
if attr in self._build_cache and not refresh and extra_config is None:
return self._build_cache[attr]

View File

@@ -82,7 +82,7 @@ def upload_sources(
except (json.JSONDecodeError, OSError) as e:
raise ClanError(
f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout}"
)
) from e
@API.register
@@ -180,7 +180,7 @@ def update(args: argparse.Namespace) -> None:
if machine.deployment.get("requireExplicitUpdate", False):
continue
try:
machine.build_host
machine.build_host # noqa: B018
except ClanError: # check if we have a build host set
ignored_machines.append(machine)
continue

View File

@@ -138,10 +138,10 @@ class QEMUMonitorProtocol:
self.__sock.settimeout(wait)
try:
ret = self.__json_read(only_event=True)
except TimeoutError:
raise QMPTimeoutError("Timeout waiting for event")
except Exception:
raise QMPConnectError("Error while reading from socket")
except TimeoutError as e:
raise QMPTimeoutError("Timeout waiting for event") from e
except OSError as e:
raise QMPConnectError("Error while reading from socket") from e
if ret is None:
raise QMPConnectError("Error while reading from socket")
self.__sock.settimeout(None)

View File

@@ -38,8 +38,8 @@ def remove_object(path: Path, name: str) -> list[Path]:
try:
shutil.rmtree(path / name)
paths_to_commit.append(path / name)
except FileNotFoundError:
raise ClanError(f"{name} not found in {path}")
except FileNotFoundError as e:
raise ClanError(f"{name} not found in {path}") from e
if not os.listdir(path):
os.rmdir(path)
return paths_to_commit

View File

@@ -22,10 +22,12 @@ def extract_public_key(filepath: Path) -> str:
if line.startswith("# public key:"):
# Extract and return the public key part after the prefix
return line.strip().split(": ")[1]
except FileNotFoundError:
raise ClanError(f"The file at {filepath} was not found.")
except Exception as e:
raise ClanError(f"An error occurred while extracting the public key: {e}")
except FileNotFoundError as e:
raise ClanError(f"The file at {filepath} was not found.") from e
except OSError as e:
raise ClanError(
f"An error occurred while extracting the public key: {e}"
) from e
raise ClanError(f"Could not find the public key in the file at {filepath}.")

View File

@@ -85,11 +85,19 @@ def encrypt_secret(
flake_dir: Path,
secret_path: Path,
value: IO[str] | str | bytes | None,
add_users: list[str] = [],
add_machines: list[str] = [],
add_groups: list[str] = [],
meta: dict = {},
add_users: list[str] | None = None,
add_machines: list[str] | None = None,
add_groups: list[str] | None = None,
meta: dict | None = None,
) -> None:
if meta is None:
meta = {}
if add_groups is None:
add_groups = []
if add_machines is None:
add_machines = []
if add_users is None:
add_users = []
key = ensure_sops_key(flake_dir)
recipient_keys = set([])

View File

@@ -147,8 +147,10 @@ def encrypt_file(
secret_path: Path,
content: IO[str] | str | bytes | None,
pubkeys: list[str],
meta: dict = {},
meta: dict | None = None,
) -> None:
if meta is None:
meta = {}
folder = secret_path.parent
folder.mkdir(parents=True, exist_ok=True)
@@ -225,10 +227,10 @@ def write_key(path: Path, publickey: str, overwrite: bool) -> None:
if not overwrite:
flags |= os.O_EXCL
fd = os.open(path / "key.json", flags)
except FileExistsError:
except FileExistsError as e:
raise ClanError(
f"{path.name} already exists in {path}. Use --force to overwrite."
)
) from e
with os.fdopen(fd, "w") as f:
json.dump({"publickey": publickey, "type": "age"}, f, indent=2)
@@ -238,7 +240,7 @@ def read_key(path: Path) -> str:
try:
key = json.load(f)
except json.JSONDecodeError as e:
raise ClanError(f"Failed to decode {path.name}: {e}")
raise ClanError(f"Failed to decode {path.name}: {e}") from e
if key["type"] != "age":
raise ClanError(
f"{path.name} is not an age key but {key['type']}. This is not supported"

View File

@@ -54,9 +54,9 @@ class CommandFormatter(logging.Formatter):
prefix_color = ansi_color(self.hostname_colorcode(command_prefix))
color_reset = "\x1b[0m"
setattr(record, "color", color)
setattr(record, "prefix_color", prefix_color)
setattr(record, "color_reset", color_reset)
record.color = color
record.prefix_color = prefix_color
record.color_reset = color_reset
return super().format(record)
@@ -144,9 +144,9 @@ class Host:
forward_agent: bool = False,
command_prefix: str | None = None,
host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
meta: dict[str, Any] = {},
meta: dict[str, Any] | None = None,
verbose_ssh: bool = False,
ssh_options: dict[str, str] = {},
ssh_options: dict[str, str] | None = None,
) -> None:
"""
Creates a Host
@@ -158,6 +158,10 @@ class Host:
@verbose_ssh: Enables verbose logging on ssh connections
@meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`
"""
if ssh_options is None:
ssh_options = {}
if meta is None:
meta = {}
self.host = host
self.user = user
self.port = port
@@ -200,7 +204,9 @@ class Host:
start = time.time()
last_output = time.time()
while len(rlist) != 0:
r, _, _ = select.select(rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT))
readlist, _, _ = select.select(
rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)
)
def print_from(
print_fd: IO[str], print_buf: str, is_err: bool = False
@@ -227,11 +233,11 @@ class Host:
last_output = time.time()
return (last_output, print_buf)
if print_std_fd in r and print_std_fd is not None:
if print_std_fd in readlist and print_std_fd is not None:
(last_output, print_std_buf) = print_from(
print_std_fd, print_std_buf, is_err=False
)
if print_err_fd in r and print_err_fd is not None:
if print_err_fd in readlist and print_err_fd is not None:
(last_output, print_err_buf) = print_from(
print_err_fd, print_err_buf, is_err=True
)
@@ -245,8 +251,8 @@ class Host:
extra=dict(command_prefix=self.command_prefix),
)
def handle_fd(fd: IO[Any] | None) -> str:
if fd and fd in r:
def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> str:
if fd and fd in readlist:
read = os.read(fd.fileno(), 4096)
if len(read) == 0:
rlist.remove(fd)
@@ -254,8 +260,8 @@ class Host:
return read.decode("utf-8")
return ""
stdout_buf += handle_fd(stdout)
stderr_buf += handle_fd(stderr)
stdout_buf += handle_fd(stdout, readlist)
stderr_buf += handle_fd(stderr, readlist)
if now - last_output >= timeout:
break
@@ -268,11 +274,13 @@ class Host:
shell: bool,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
) -> subprocess.CompletedProcess[str]:
if extra_env is None:
extra_env = {}
with ExitStack() as stack:
read_std_fd, write_std_fd = (None, None)
read_err_fd, write_err_fd = (None, None)
@@ -354,7 +362,7 @@ class Host:
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
@@ -371,6 +379,8 @@ class Host:
@return subprocess.CompletedProcess result of the command
"""
if extra_env is None:
extra_env = {}
shell = False
if isinstance(cmd, str):
cmd = [cmd]
@@ -397,7 +407,7 @@ class Host:
stdout: FILE = None,
stderr: FILE = None,
become_root: bool = False,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
@@ -418,6 +428,8 @@ class Host:
@return subprocess.CompletedProcess result of the ssh command
"""
if extra_env is None:
extra_env = {}
sudo = ""
if become_root and self.user != "root":
sudo = "sudo -- "
@@ -548,13 +560,15 @@ class HostGroup:
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if extra_env is None:
extra_env = {}
try:
proc = host.run_local(
cmd,
@@ -577,13 +591,15 @@ class HostGroup:
results: Results,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
timeout: float = math.inf,
tty: bool = False,
) -> None:
if extra_env is None:
extra_env = {}
try:
proc = host.run(
cmd,
@@ -622,13 +638,15 @@ class HostGroup:
local: bool = False,
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
verbose_ssh: bool = False,
tty: bool = False,
) -> Results:
if extra_env is None:
extra_env = {}
results: Results = []
threads = []
for host in self.hosts:
@@ -665,7 +683,7 @@ class HostGroup:
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
verbose_ssh: bool = False,
@@ -682,6 +700,8 @@ class HostGroup:
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
stdout=stdout,
@@ -699,7 +719,7 @@ class HostGroup:
cmd: str | list[str],
stdout: FILE = None,
stderr: FILE = None,
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None,
check: bool = True,
timeout: float = math.inf,
@@ -715,6 +735,8 @@ class HostGroup:
@return a lists of tuples containing Host and the result of the command for this Host
"""
if extra_env is None:
extra_env = {}
return self._run(
cmd,
local=True,
@@ -761,8 +783,13 @@ class HostGroup:
def parse_deployment_address(
machine_name: str, host: str, forward_agent: bool = True, meta: dict[str, Any] = {}
machine_name: str,
host: str,
forward_agent: bool = True,
meta: dict[str, Any] | None = None,
) -> Host:
if meta is None:
meta = {}
parts = host.split("@")
user: str | None = None
if len(parts) > 1:

View File

@@ -15,9 +15,11 @@ def ssh(
host: str,
user: str = "root",
password: str | None = None,
ssh_args: list[str] = [],
ssh_args: list[str] | None = None,
torify: bool = False,
) -> None:
if ssh_args is None:
ssh_args = []
packages = ["nixpkgs#openssh"]
if torify:
packages.append("nixpkgs#tor")

View File

@@ -33,12 +33,12 @@ def list_state_folders(machine: str, service: None | str = None) -> None:
try:
proc = run_no_stdout(cmd)
res = proc.stdout.strip()
except ClanCmdError:
except ClanCmdError as e:
raise ClanError(
"Clan might not have meta attributes",
location=f"show_clan {uri}",
description="Evaluation failed on clanInternals.meta attribute",
)
) from e
state = json.loads(res)
if service:

View File

@@ -94,8 +94,10 @@ def qemu_command(
virtiofsd_socket: Path,
qmp_socket_file: Path,
qga_socket_file: Path,
portmap: list[tuple[int, int]] = [],
portmap: list[tuple[int, int]] | None = None,
) -> QemuCommand:
if portmap is None:
portmap = []
kernel_cmdline = [
(Path(nixos_config["toplevel"]) / "kernel-params").read_text(),
f'init={nixos_config["toplevel"]}/init',

View File

@@ -39,9 +39,11 @@ def facts_to_nixos_config(facts: dict[str, dict[str, bytes]]) -> dict:
# TODO move this to the Machines class
def build_vm(
machine: Machine, tmpdir: Path, nix_options: list[str] = []
machine: Machine, tmpdir: Path, nix_options: list[str] | None = None
) -> dict[str, str]:
# TODO pass prompt here for the GTK gui
if nix_options is None:
nix_options = []
secrets_dir = get_secrets(machine, tmpdir)
public_facts_module = importlib.import_module(machine.public_facts_module)
@@ -58,7 +60,7 @@ def build_vm(
vm_data["secrets_dir"] = str(secrets_dir)
return vm_data
except json.JSONDecodeError as e:
raise ClanError(f"Failed to parse vm config: {e}")
raise ClanError(f"Failed to parse vm config: {e}") from e
def get_secrets(
@@ -108,9 +110,13 @@ def run_vm(
*,
cachedir: Path | None = None,
socketdir: Path | None = None,
nix_options: list[str] = [],
portmap: list[tuple[int, int]] = [],
nix_options: list[str] | None = None,
portmap: list[tuple[int, int]] | None = None,
) -> None:
if portmap is None:
portmap = []
if nix_options is None:
nix_options = []
with ExitStack() as stack:
machine = Machine(name=vm.machine_name, flake=vm.flake_url)
log.debug(f"Creating VM for {machine}")

View File

@@ -156,7 +156,7 @@ def get_subcommands(
parser: argparse.ArgumentParser,
to: list[Category],
level: int = 0,
prefix: list[str] = [],
prefix: list[str] | None = None,
) -> tuple[list[Option], list[Option], list[Subcommand]]:
"""
Generate Markdown documentation for an argparse.ArgumentParser instance including its subcommands.
@@ -168,6 +168,8 @@ def get_subcommands(
# Document each argument
# --flake --option --debug, etc.
if prefix is None:
prefix = []
flag_options: list[Option] = []
positional_options: list[Option] = []
subcommands: list[Subcommand] = []

View File

@@ -67,5 +67,16 @@ ignore_missing_imports = true
[tool.ruff]
target-version = "py311"
line-length = 88
lint.select = [ "E", "F", "I", "U", "N", "RUF", "ANN", "A", "TID" ]
lint.select = [
"A",
"ANN",
"B",
"E",
"F",
"I",
"N",
"RUF",
"TID",
"U",
]
lint.ignore = ["E501", "E402", "E731", "ANN101", "ANN401", "A003"]

View File

@@ -17,12 +17,14 @@ class Command:
def run(
self,
command: list[str],
extra_env: dict[str, str] = {},
extra_env: dict[str, str] | None = None,
stdin: _FILE = None,
stdout: _FILE = None,
stderr: _FILE = None,
workdir: Path | None = None,
) -> subprocess.Popen[str]:
if extra_env is None:
extra_env = {}
env = os.environ.copy()
env.update(extra_env)
# We start a new session here so that we can than more reliably kill all childs as well

View File

@@ -53,15 +53,10 @@ class FlakeForTest(NamedTuple):
def generate_flake(
temporary_home: Path,
flake_template: Path,
substitutions: dict[str, str] = {
"__CHANGE_ME__": "_test_vm_persistence",
"git+https://git.clan.lol/clan/clan-core": "path://" + str(CLAN_CORE),
"https://git.clan.lol/clan/clan-core/archive/main.tar.gz": "path://"
+ str(CLAN_CORE),
},
substitutions: dict[str, str] | None = None,
# define the machines directly including their config
machine_configs: dict[str, dict] = {},
inventory: dict[str, dict] = {},
machine_configs: dict[str, dict] | None = None,
inventory: dict[str, dict] | None = None,
) -> FlakeForTest:
"""
Creates a clan flake with the given name.
@@ -82,6 +77,17 @@ def generate_flake(
"""
# copy the template to a new temporary location
if inventory is None:
inventory = {}
if machine_configs is None:
machine_configs = {}
if substitutions is None:
substitutions = {
"__CHANGE_ME__": "_test_vm_persistence",
"git+https://git.clan.lol/clan/clan-core": "path://" + str(CLAN_CORE),
"https://git.clan.lol/clan/clan-core/archive/main.tar.gz": "path://"
+ str(CLAN_CORE),
}
flake = temporary_home / "flake"
shutil.copytree(flake_template, flake)
sp.run(["chmod", "+w", "-R", str(flake)], check=True)
@@ -136,15 +142,19 @@ def create_flake(
flake_template: str | Path,
clan_core_flake: Path | None = None,
# names referring to pre-defined machines from ../machines
machines: list[str] = [],
machines: list[str] | None = None,
# alternatively specify the machines directly including their config
machine_configs: dict[str, dict] = {},
machine_configs: dict[str, dict] | None = None,
remote: bool = False,
) -> Iterator[FlakeForTest]:
"""
Creates a flake with the given name and machines.
The machine names map to the machines in ./test_machines
"""
if machine_configs is None:
machine_configs = {}
if machines is None:
machines = []
if isinstance(flake_template, Path):
template_path = flake_template
else:

View File

@@ -11,7 +11,7 @@ from clan_cli.errors import ClanError
def find_dataclasses_in_directory(
directory: Path, exclude_paths: list[str] = []
directory: Path, exclude_paths: list[str] | None = None
) -> list[tuple[str, str]]:
"""
Find all dataclass classes in all Python files within a nested directory.
@@ -22,6 +22,8 @@ def find_dataclasses_in_directory(
Returns:
List[Tuple[str, str]]: A list of tuples containing the file path and the dataclass name.
"""
if exclude_paths is None:
exclude_paths = []
dataclass_files = []
excludes = [os.path.join(directory, d) for d in exclude_paths]
@@ -144,4 +146,4 @@ Help:
--------------------------------------------------------------------------------
""",
location=__file__,
)
) from e

View File

@@ -27,7 +27,7 @@ def test_timeout() -> None:
except Exception:
pass
else:
assert False, "should have raised TimeoutExpired"
raise AssertionError("should have raised TimeoutExpired")
def test_run_function() -> None:
@@ -45,7 +45,7 @@ def test_run_exception() -> None:
except Exception:
pass
else:
assert False, "should have raised Exception"
raise AssertionError("should have raised Exception")
def test_run_function_exception() -> None:
@@ -57,7 +57,7 @@ def test_run_function_exception() -> None:
except Exception:
pass
else:
assert False, "should have raised Exception"
raise AssertionError("should have raised Exception")
def test_run_local_non_shell() -> None:

View File

@@ -46,7 +46,7 @@ def test_timeout(host_group: HostGroup) -> None:
except Exception:
pass
else:
assert False, "should have raised TimeoutExpired"
raise AssertionError("should have raised TimeoutExpired")
def test_run_exception(host_group: HostGroup) -> None:
@@ -58,7 +58,7 @@ def test_run_exception(host_group: HostGroup) -> None:
except Exception:
pass
else:
assert False, "should have raised Exception"
raise AssertionError("should have raised Exception")
def test_run_function_exception(host_group: HostGroup) -> None:
@@ -70,4 +70,4 @@ def test_run_function_exception(host_group: HostGroup) -> None:
except Exception:
pass
else:
assert False, "should have raised Exception"
raise AssertionError("should have raised Exception")