enable bug-bear linting rules

This commit is contained in:
Jörg Thalheim
2024-09-02 13:26:07 +02:00
parent af4b9cc2d5
commit 35839ef701
33 changed files with 214 additions and 104 deletions

View File

@@ -85,7 +85,7 @@ def register_common_flags(parser: argparse.ArgumentParser) -> None:
has_subparsers = False has_subparsers = False
for action in parser._actions: for action in parser._actions:
if isinstance(action, argparse._SubParsersAction): if isinstance(action, argparse._SubParsersAction):
for choice, child_parser in action.choices.items(): for _choice, child_parser in action.choices.items():
has_subparsers = True has_subparsers = True
register_common_flags(child_parser) register_common_flags(child_parser)
if not has_subparsers: if not has_subparsers:

View File

@@ -27,12 +27,14 @@ def set_admin_service(
base_url: str, base_url: str,
allowed_keys: dict[str, str], allowed_keys: dict[str, str],
instance_name: str = "admin", instance_name: str = "admin",
extra_machines: list[str] = [], extra_machines: list[str] | None = None,
) -> None: ) -> None:
""" """
Set the admin service of a clan Set the admin service of a clan
Every machine is by default part of the admin service via the 'all' tag Every machine is by default part of the admin service via the 'all' tag
""" """
if extra_machines is None:
extra_machines = []
inventory = load_inventory_eval(base_url) inventory = load_inventory_eval(base_url)
if not allowed_keys: if not allowed_keys:

View File

@@ -55,7 +55,7 @@ def extract_frontmatter(readme_content: str, err_scope: str) -> tuple[Frontmatte
f"Error parsing TOML frontmatter: {e}", f"Error parsing TOML frontmatter: {e}",
description=f"Invalid TOML frontmatter. {err_scope}", description=f"Invalid TOML frontmatter. {err_scope}",
location="extract_frontmatter", location="extract_frontmatter",
) ) from e
return Frontmatter(**frontmatter_parsed), remaining_content return Frontmatter(**frontmatter_parsed), remaining_content
@@ -97,12 +97,12 @@ def get_modules(base_path: str) -> dict[str, str]:
try: try:
proc = run_no_stdout(cmd) proc = run_no_stdout(cmd)
res = proc.stdout.strip() res = proc.stdout.strip()
except ClanCmdError: except ClanCmdError as e:
raise ClanError( raise ClanError(
"clanInternals might not have clanModules attributes", "clanInternals might not have clanModules attributes",
location=f"list_modules {base_path}", location=f"list_modules {base_path}",
description="Evaluation failed on clanInternals.clanModules attribute", description="Evaluation failed on clanInternals.clanModules attribute",
) ) from e
modules: dict[str, str] = json.loads(res) modules: dict[str, str] = json.loads(res)
return modules return modules

View File

@@ -121,11 +121,13 @@ JsonValue = str | float | dict[str, Any] | list[Any] | None
def construct_value( def construct_value(
t: type | UnionType, field_value: JsonValue, loc: list[str] = [] t: type | UnionType, field_value: JsonValue, loc: list[str] | None = None
) -> Any: ) -> Any:
""" """
Construct a field value from a type hint and a field value. Construct a field value from a type hint and a field value.
""" """
if loc is None:
loc = []
if t is None and field_value: if t is None and field_value:
raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}") raise ClanError(f"Expected None but got: {field_value}", location=f"{loc}")
@@ -203,11 +205,15 @@ def construct_value(
raise ClanError(f"Unhandled field type {t} with value {field_value}") raise ClanError(f"Unhandled field type {t} with value {field_value}")
def construct_dataclass(t: type[T], data: dict[str, Any], path: list[str] = []) -> T: def construct_dataclass(
t: type[T], data: dict[str, Any], path: list[str] | None = None
) -> T:
""" """
type t MUST be a dataclass type t MUST be a dataclass
Dynamically instantiate a data class from a dictionary, handling nested data classes. Dynamically instantiate a data class from a dictionary, handling nested data classes.
""" """
if path is None:
path = []
if not is_dataclass(t): if not is_dataclass(t):
raise ClanError(f"{t.__name__} is not a dataclass") raise ClanError(f"{t.__name__} is not a dataclass")
@@ -253,8 +259,10 @@ def construct_dataclass(t: type[T], data: dict[str, Any], path: list[str] = [])
def from_dict( def from_dict(
t: type | UnionType, data: dict[str, Any] | Any, path: list[str] = [] t: type | UnionType, data: dict[str, Any] | Any, path: list[str] | None = None
) -> Any: ) -> Any:
if path is None:
path = []
if is_dataclass(t): if is_dataclass(t):
if not isinstance(data, dict): if not isinstance(data, dict):
raise ClanError(f"{data} is not a dict. Expected {t}") raise ClanError(f"{data} is not a dict. Expected {t}")

View File

@@ -30,7 +30,7 @@ def inspect_dataclass_fields(t: type) -> dict[TypeVar, type]:
type_params = origin.__parameters__ type_params = origin.__parameters__
# Create a map from type parameters to actual type arguments # Create a map from type parameters to actual type arguments
type_map = dict(zip(type_params, type_args)) type_map = dict(zip(type_params, type_args, strict=False))
return type_map return type_map
@@ -67,7 +67,11 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
return schema return schema
def type_to_dict(t: Any, scope: str = "", type_map: dict[TypeVar, type] = {}) -> dict: def type_to_dict(
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
) -> dict:
if type_map is None:
type_map = {}
if t is None: if t is None:
return {"type": "null"} return {"type": "null"}

View File

@@ -31,7 +31,7 @@ def show_clan_meta(uri: str | Path) -> Meta:
"Evaluation failed on meta attribute", "Evaluation failed on meta attribute",
location=f"show_clan {uri}", location=f"show_clan {uri}",
description=str(e.cmd), description=str(e.cmd),
) ) from e
clan_meta = json.loads(res) clan_meta = json.loads(res)

View File

@@ -29,28 +29,28 @@ def handle_output(process: subprocess.Popen, log: Log) -> tuple[str, str]:
stderr_buf = b"" stderr_buf = b""
while len(rlist) != 0: while len(rlist) != 0:
r, _, _ = select.select(rlist, [], [], 0.1) readlist, _, _ = select.select(rlist, [], [], 0.1)
if len(r) == 0: # timeout in select if len(readlist) == 0: # timeout in select
if process.poll() is None: if process.poll() is None:
continue continue
# Process has exited # Process has exited
break break
def handle_fd(fd: IO[Any] | None) -> bytes: def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> bytes:
if fd and fd in r: if fd and fd in readlist:
read = os.read(fd.fileno(), 4096) read = os.read(fd.fileno(), 4096)
if len(read) != 0: if len(read) != 0:
return read return read
rlist.remove(fd) rlist.remove(fd)
return b"" return b""
ret = handle_fd(process.stdout) ret = handle_fd(process.stdout, readlist)
if ret and log in [Log.STDOUT, Log.BOTH]: if ret and log in [Log.STDOUT, Log.BOTH]:
sys.stdout.buffer.write(ret) sys.stdout.buffer.write(ret)
sys.stdout.flush() sys.stdout.flush()
stdout_buf += ret stdout_buf += ret
ret = handle_fd(process.stderr) ret = handle_fd(process.stderr, readlist)
if ret and log in [Log.STDERR, Log.BOTH]: if ret and log in [Log.STDERR, Log.BOTH]:
sys.stderr.buffer.write(ret) sys.stderr.buffer.write(ret)
@@ -103,11 +103,13 @@ def run(
*, *,
input: bytes | None = None, # noqa: A002 input: bytes | None = None, # noqa: A002
env: dict[str, str] | None = None, env: dict[str, str] | None = None,
cwd: Path = Path.cwd(), cwd: Path | None = None,
log: Log = Log.STDERR, log: Log = Log.STDERR,
check: bool = True, check: bool = True,
error_msg: str | None = None, error_msg: str | None = None,
) -> CmdOut: ) -> CmdOut:
if cwd is None:
cwd = Path.cwd()
if input: if input:
glog.debug( glog.debug(
f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}""" f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
@@ -155,7 +157,7 @@ def run_no_stdout(
cmd: list[str], cmd: list[str],
*, *,
env: dict[str, str] | None = None, env: dict[str, str] | None = None,
cwd: Path = Path.cwd(), cwd: Path | None = None,
log: Log = Log.STDERR, log: Log = Log.STDERR,
check: bool = True, check: bool = True,
error_msg: str | None = None, error_msg: str | None = None,
@@ -164,6 +166,8 @@ def run_no_stdout(
Like run, but automatically suppresses stdout, if not in DEBUG log level. Like run, but automatically suppresses stdout, if not in DEBUG log level.
If in DEBUG log level the stdout of commands will be shown. If in DEBUG log level the stdout of commands will be shown.
""" """
if cwd is None:
cwd = Path.cwd()
if logging.getLogger(__name__.split(".")[0]).isEnabledFor(logging.DEBUG): if logging.getLogger(__name__.split(".")[0]).isEnabledFor(logging.DEBUG):
return run(cmd, env=env, log=log, check=check, error_msg=error_msg) return run(cmd, env=env, log=log, check=check, error_msg=error_msg)
else: else:

View File

@@ -44,7 +44,9 @@ def map_type(nix_type: str) -> Any:
# merge two dicts recursively # merge two dicts recursively
def merge(a: dict, b: dict, path: list[str] = []) -> dict: def merge(a: dict, b: dict, path: list[str] | None = None) -> dict:
if path is None:
path = []
for key in b: for key in b:
if key in a: if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict): if isinstance(a[key], dict) and isinstance(b[key], dict):
@@ -98,10 +100,10 @@ def cast(value: Any, input_type: Any, opt_description: str) -> Any:
if len(value) > 1: if len(value) > 1:
raise ClanError(f"Too many values for {opt_description}") raise ClanError(f"Too many values for {opt_description}")
return input_type(value[0]) return input_type(value[0])
except ValueError: except ValueError as e:
raise ClanError( raise ClanError(
f"Invalid type for option {opt_description} (expected {input_type.__name__})" f"Invalid type for option {opt_description} (expected {input_type.__name__})"
) ) from e
def options_for_machine( def options_for_machine(

View File

@@ -18,8 +18,10 @@ def machine_schema(
flake_dir: Path, flake_dir: Path,
config: dict[str, Any], config: dict[str, Any],
clan_imports: list[str] | None = None, clan_imports: list[str] | None = None,
option_path: list[str] = ["clan"], option_path: list[str] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
if option_path is None:
option_path = ["clan"]
# use nix eval to lib.evalModules .#nixosConfigurations.<machine_name>.options.clan # use nix eval to lib.evalModules .#nixosConfigurations.<machine_name>.options.clan
with NamedTemporaryFile(mode="w", dir=flake_dir) as clan_machine_settings_file: with NamedTemporaryFile(mode="w", dir=flake_dir) as clan_machine_settings_file:
env = os.environ.copy() env = os.environ.copy()

View File

@@ -48,7 +48,7 @@ def list_possible_keymaps() -> list[str]:
keymap_files = [] keymap_files = []
for root, _, files in os.walk(keymaps_dir): for _root, _, files in os.walk(keymaps_dir):
for file in files: for file in files:
if file.endswith(".map.gz"): if file.endswith(".map.gz"):
# Remove '.map.gz' ending # Remove '.map.gz' ending
@@ -93,8 +93,10 @@ def flash_machine(
dry_run: bool, dry_run: bool,
write_efi_boot_entries: bool, write_efi_boot_entries: bool,
debug: bool, debug: bool,
extra_args: list[str] = [], extra_args: list[str] | None = None,
) -> None: ) -> None:
if extra_args is None:
extra_args = []
system_config_nix: dict[str, Any] = {} system_config_nix: dict[str, Any] = {}
if system_config.wifi_settings: if system_config.wifi_settings:
@@ -125,7 +127,9 @@ def flash_machine(
try: try:
root_keys.append(key_path.read_text()) root_keys.append(key_path.read_text())
except OSError as e: except OSError as e:
raise ClanError(f"Cannot read SSH public key file: {key_path}: {e}") raise ClanError(
f"Cannot read SSH public key file: {key_path}: {e}"
) from e
system_config_nix["users"] = { system_config_nix["users"] = {
"users": {"root": {"openssh": {"authorizedKeys": {"keys": root_keys}}}} "users": {"root": {"openssh": {"authorizedKeys": {"keys": root_keys}}}}
} }

View File

@@ -114,7 +114,7 @@ def load_inventory_eval(flake_dir: str | Path) -> Inventory:
inventory = from_dict(Inventory, data) inventory = from_dict(Inventory, data)
return inventory return inventory
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ClanError(f"Error decoding inventory from flake: {e}") raise ClanError(f"Error decoding inventory from flake: {e}") from e
def load_inventory_json( def load_inventory_json(
@@ -134,7 +134,7 @@ def load_inventory_json(
inventory = from_dict(Inventory, res) inventory = from_dict(Inventory, res)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# Error decoding the inventory file # Error decoding the inventory file
raise ClanError(f"Error decoding inventory file: {e}") raise ClanError(f"Error decoding inventory file: {e}") from e
if not inventory_file.exists(): if not inventory_file.exists():
# Copy over the meta from the flake if the inventory is not initialized # Copy over the meta from the flake if the inventory is not initialized

View File

@@ -209,7 +209,7 @@ def generate_machine_hardware_info(
"Invalid hardware-configuration.nix file", "Invalid hardware-configuration.nix file",
description="The hardware-configuration.nix file is invalid. Please check the file and try again.", description="The hardware-configuration.nix file is invalid. Please check the file and try again.",
location=f"{__name__} {hw_file}", location=f"{__name__} {hw_file}",
) ) from e
return HardwareReport(report_type) return HardwareReport(report_type)

View File

@@ -29,8 +29,10 @@ def install_nixos(
debug: bool = False, debug: bool = False,
password: str | None = None, password: str | None = None,
no_reboot: bool = False, no_reboot: bool = False,
extra_args: list[str] = [], extra_args: list[str] | None = None,
) -> None: ) -> None:
if extra_args is None:
extra_args = []
secret_facts_module = importlib.import_module(machine.secret_facts_module) secret_facts_module = importlib.import_module(machine.secret_facts_module)
log.info(f"installing {machine.name}") log.info(f"installing {machine.name}")
secret_facts_store = secret_facts_module.SecretStore(machine=machine) secret_facts_store = secret_facts_module.SecretStore(machine=machine)

View File

@@ -73,7 +73,7 @@ def list_nixos_machines(flake_url: str | Path) -> list[str]:
data = json.loads(res) data = json.loads(res)
return data return data
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ClanError(f"Error decoding machines from flake: {e}") raise ClanError(f"Error decoding machines from flake: {e}") from e
@dataclass @dataclass

View File

@@ -133,12 +133,14 @@ class Machine:
attr: str, attr: str,
extra_config: None | dict = None, extra_config: None | dict = None,
impure: bool = False, impure: bool = False,
nix_options: list[str] = [], nix_options: list[str] | None = None,
) -> str | Path: ) -> str | Path:
""" """
Build the machine and return the path to the result Build the machine and return the path to the result
accepts a secret store and a facts store # TODO accepts a secret store and a facts store # TODO
""" """
if nix_options is None:
nix_options = []
config = nix_config() config = nix_config()
system = config["system"] system = config["system"]
@@ -216,12 +218,14 @@ class Machine:
refresh: bool = False, refresh: bool = False,
extra_config: None | dict = None, extra_config: None | dict = None,
impure: bool = False, impure: bool = False,
nix_options: list[str] = [], nix_options: list[str] | None = None,
) -> str: ) -> str:
""" """
eval a nix attribute of the machine eval a nix attribute of the machine
@attr: the attribute to get @attr: the attribute to get
""" """
if nix_options is None:
nix_options = []
if attr in self._eval_cache and not refresh and extra_config is None: if attr in self._eval_cache and not refresh and extra_config is None:
return self._eval_cache[attr] return self._eval_cache[attr]
@@ -238,13 +242,15 @@ class Machine:
refresh: bool = False, refresh: bool = False,
extra_config: None | dict = None, extra_config: None | dict = None,
impure: bool = False, impure: bool = False,
nix_options: list[str] = [], nix_options: list[str] | None = None,
) -> Path: ) -> Path:
""" """
build a nix attribute of the machine build a nix attribute of the machine
@attr: the attribute to get @attr: the attribute to get
""" """
if nix_options is None:
nix_options = []
if attr in self._build_cache and not refresh and extra_config is None: if attr in self._build_cache and not refresh and extra_config is None:
return self._build_cache[attr] return self._build_cache[attr]

View File

@@ -82,7 +82,7 @@ def upload_sources(
except (json.JSONDecodeError, OSError) as e: except (json.JSONDecodeError, OSError) as e:
raise ClanError( raise ClanError(
f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout}" f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout}"
) ) from e
@API.register @API.register
@@ -180,7 +180,7 @@ def update(args: argparse.Namespace) -> None:
if machine.deployment.get("requireExplicitUpdate", False): if machine.deployment.get("requireExplicitUpdate", False):
continue continue
try: try:
machine.build_host machine.build_host # noqa: B018
except ClanError: # check if we have a build host set except ClanError: # check if we have a build host set
ignored_machines.append(machine) ignored_machines.append(machine)
continue continue

View File

@@ -138,10 +138,10 @@ class QEMUMonitorProtocol:
self.__sock.settimeout(wait) self.__sock.settimeout(wait)
try: try:
ret = self.__json_read(only_event=True) ret = self.__json_read(only_event=True)
except TimeoutError: except TimeoutError as e:
raise QMPTimeoutError("Timeout waiting for event") raise QMPTimeoutError("Timeout waiting for event") from e
except Exception: except OSError as e:
raise QMPConnectError("Error while reading from socket") raise QMPConnectError("Error while reading from socket") from e
if ret is None: if ret is None:
raise QMPConnectError("Error while reading from socket") raise QMPConnectError("Error while reading from socket")
self.__sock.settimeout(None) self.__sock.settimeout(None)

View File

@@ -38,8 +38,8 @@ def remove_object(path: Path, name: str) -> list[Path]:
try: try:
shutil.rmtree(path / name) shutil.rmtree(path / name)
paths_to_commit.append(path / name) paths_to_commit.append(path / name)
except FileNotFoundError: except FileNotFoundError as e:
raise ClanError(f"{name} not found in {path}") raise ClanError(f"{name} not found in {path}") from e
if not os.listdir(path): if not os.listdir(path):
os.rmdir(path) os.rmdir(path)
return paths_to_commit return paths_to_commit

View File

@@ -22,10 +22,12 @@ def extract_public_key(filepath: Path) -> str:
if line.startswith("# public key:"): if line.startswith("# public key:"):
# Extract and return the public key part after the prefix # Extract and return the public key part after the prefix
return line.strip().split(": ")[1] return line.strip().split(": ")[1]
except FileNotFoundError: except FileNotFoundError as e:
raise ClanError(f"The file at {filepath} was not found.") raise ClanError(f"The file at {filepath} was not found.") from e
except Exception as e: except OSError as e:
raise ClanError(f"An error occurred while extracting the public key: {e}") raise ClanError(
f"An error occurred while extracting the public key: {e}"
) from e
raise ClanError(f"Could not find the public key in the file at {filepath}.") raise ClanError(f"Could not find the public key in the file at {filepath}.")

View File

@@ -85,11 +85,19 @@ def encrypt_secret(
flake_dir: Path, flake_dir: Path,
secret_path: Path, secret_path: Path,
value: IO[str] | str | bytes | None, value: IO[str] | str | bytes | None,
add_users: list[str] = [], add_users: list[str] | None = None,
add_machines: list[str] = [], add_machines: list[str] | None = None,
add_groups: list[str] = [], add_groups: list[str] | None = None,
meta: dict = {}, meta: dict | None = None,
) -> None: ) -> None:
if meta is None:
meta = {}
if add_groups is None:
add_groups = []
if add_machines is None:
add_machines = []
if add_users is None:
add_users = []
key = ensure_sops_key(flake_dir) key = ensure_sops_key(flake_dir)
recipient_keys = set([]) recipient_keys = set([])

View File

@@ -147,8 +147,10 @@ def encrypt_file(
secret_path: Path, secret_path: Path,
content: IO[str] | str | bytes | None, content: IO[str] | str | bytes | None,
pubkeys: list[str], pubkeys: list[str],
meta: dict = {}, meta: dict | None = None,
) -> None: ) -> None:
if meta is None:
meta = {}
folder = secret_path.parent folder = secret_path.parent
folder.mkdir(parents=True, exist_ok=True) folder.mkdir(parents=True, exist_ok=True)
@@ -225,10 +227,10 @@ def write_key(path: Path, publickey: str, overwrite: bool) -> None:
if not overwrite: if not overwrite:
flags |= os.O_EXCL flags |= os.O_EXCL
fd = os.open(path / "key.json", flags) fd = os.open(path / "key.json", flags)
except FileExistsError: except FileExistsError as e:
raise ClanError( raise ClanError(
f"{path.name} already exists in {path}. Use --force to overwrite." f"{path.name} already exists in {path}. Use --force to overwrite."
) ) from e
with os.fdopen(fd, "w") as f: with os.fdopen(fd, "w") as f:
json.dump({"publickey": publickey, "type": "age"}, f, indent=2) json.dump({"publickey": publickey, "type": "age"}, f, indent=2)
@@ -238,7 +240,7 @@ def read_key(path: Path) -> str:
try: try:
key = json.load(f) key = json.load(f)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ClanError(f"Failed to decode {path.name}: {e}") raise ClanError(f"Failed to decode {path.name}: {e}") from e
if key["type"] != "age": if key["type"] != "age":
raise ClanError( raise ClanError(
f"{path.name} is not an age key but {key['type']}. This is not supported" f"{path.name} is not an age key but {key['type']}. This is not supported"

View File

@@ -54,9 +54,9 @@ class CommandFormatter(logging.Formatter):
prefix_color = ansi_color(self.hostname_colorcode(command_prefix)) prefix_color = ansi_color(self.hostname_colorcode(command_prefix))
color_reset = "\x1b[0m" color_reset = "\x1b[0m"
setattr(record, "color", color) record.color = color
setattr(record, "prefix_color", prefix_color) record.prefix_color = prefix_color
setattr(record, "color_reset", color_reset) record.color_reset = color_reset
return super().format(record) return super().format(record)
@@ -144,9 +144,9 @@ class Host:
forward_agent: bool = False, forward_agent: bool = False,
command_prefix: str | None = None, command_prefix: str | None = None,
host_key_check: HostKeyCheck = HostKeyCheck.STRICT, host_key_check: HostKeyCheck = HostKeyCheck.STRICT,
meta: dict[str, Any] = {}, meta: dict[str, Any] | None = None,
verbose_ssh: bool = False, verbose_ssh: bool = False,
ssh_options: dict[str, str] = {}, ssh_options: dict[str, str] | None = None,
) -> None: ) -> None:
""" """
Creates a Host Creates a Host
@@ -158,6 +158,10 @@ class Host:
@verbose_ssh: Enables verbose logging on ssh connections @verbose_ssh: Enables verbose logging on ssh connections
@meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function` @meta: meta attributes associated with the host. Those can be accessed in custom functions passed to `run_function`
""" """
if ssh_options is None:
ssh_options = {}
if meta is None:
meta = {}
self.host = host self.host = host
self.user = user self.user = user
self.port = port self.port = port
@@ -200,7 +204,9 @@ class Host:
start = time.time() start = time.time()
last_output = time.time() last_output = time.time()
while len(rlist) != 0: while len(rlist) != 0:
r, _, _ = select.select(rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)) readlist, _, _ = select.select(
rlist, [], [], min(timeout, NO_OUTPUT_TIMEOUT)
)
def print_from( def print_from(
print_fd: IO[str], print_buf: str, is_err: bool = False print_fd: IO[str], print_buf: str, is_err: bool = False
@@ -227,11 +233,11 @@ class Host:
last_output = time.time() last_output = time.time()
return (last_output, print_buf) return (last_output, print_buf)
if print_std_fd in r and print_std_fd is not None: if print_std_fd in readlist and print_std_fd is not None:
(last_output, print_std_buf) = print_from( (last_output, print_std_buf) = print_from(
print_std_fd, print_std_buf, is_err=False print_std_fd, print_std_buf, is_err=False
) )
if print_err_fd in r and print_err_fd is not None: if print_err_fd in readlist and print_err_fd is not None:
(last_output, print_err_buf) = print_from( (last_output, print_err_buf) = print_from(
print_err_fd, print_err_buf, is_err=True print_err_fd, print_err_buf, is_err=True
) )
@@ -245,8 +251,8 @@ class Host:
extra=dict(command_prefix=self.command_prefix), extra=dict(command_prefix=self.command_prefix),
) )
def handle_fd(fd: IO[Any] | None) -> str: def handle_fd(fd: IO[Any] | None, readlist: list[IO[Any]]) -> str:
if fd and fd in r: if fd and fd in readlist:
read = os.read(fd.fileno(), 4096) read = os.read(fd.fileno(), 4096)
if len(read) == 0: if len(read) == 0:
rlist.remove(fd) rlist.remove(fd)
@@ -254,8 +260,8 @@ class Host:
return read.decode("utf-8") return read.decode("utf-8")
return "" return ""
stdout_buf += handle_fd(stdout) stdout_buf += handle_fd(stdout, readlist)
stderr_buf += handle_fd(stderr) stderr_buf += handle_fd(stderr, readlist)
if now - last_output >= timeout: if now - last_output >= timeout:
break break
@@ -268,11 +274,13 @@ class Host:
shell: bool, shell: bool,
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
) -> subprocess.CompletedProcess[str]: ) -> subprocess.CompletedProcess[str]:
if extra_env is None:
extra_env = {}
with ExitStack() as stack: with ExitStack() as stack:
read_std_fd, write_std_fd = (None, None) read_std_fd, write_std_fd = (None, None)
read_err_fd, write_err_fd = (None, None) read_err_fd, write_err_fd = (None, None)
@@ -354,7 +362,7 @@ class Host:
cmd: str | list[str], cmd: str | list[str],
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
@@ -371,6 +379,8 @@ class Host:
@return subprocess.CompletedProcess result of the command @return subprocess.CompletedProcess result of the command
""" """
if extra_env is None:
extra_env = {}
shell = False shell = False
if isinstance(cmd, str): if isinstance(cmd, str):
cmd = [cmd] cmd = [cmd]
@@ -397,7 +407,7 @@ class Host:
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
become_root: bool = False, become_root: bool = False,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
@@ -418,6 +428,8 @@ class Host:
@return subprocess.CompletedProcess result of the ssh command @return subprocess.CompletedProcess result of the ssh command
""" """
if extra_env is None:
extra_env = {}
sudo = "" sudo = ""
if become_root and self.user != "root": if become_root and self.user != "root":
sudo = "sudo -- " sudo = "sudo -- "
@@ -548,13 +560,15 @@ class HostGroup:
results: Results, results: Results,
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf, timeout: float = math.inf,
tty: bool = False, tty: bool = False,
) -> None: ) -> None:
if extra_env is None:
extra_env = {}
try: try:
proc = host.run_local( proc = host.run_local(
cmd, cmd,
@@ -577,13 +591,15 @@ class HostGroup:
results: Results, results: Results,
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
timeout: float = math.inf, timeout: float = math.inf,
tty: bool = False, tty: bool = False,
) -> None: ) -> None:
if extra_env is None:
extra_env = {}
try: try:
proc = host.run( proc = host.run(
cmd, cmd,
@@ -622,13 +638,15 @@ class HostGroup:
local: bool = False, local: bool = False,
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
verbose_ssh: bool = False, verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
) -> Results: ) -> Results:
if extra_env is None:
extra_env = {}
results: Results = [] results: Results = []
threads = [] threads = []
for host in self.hosts: for host in self.hosts:
@@ -665,7 +683,7 @@ class HostGroup:
cmd: str | list[str], cmd: str | list[str],
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
verbose_ssh: bool = False, verbose_ssh: bool = False,
@@ -682,6 +700,8 @@ class HostGroup:
@return a lists of tuples containing Host and the result of the command for this Host @return a lists of tuples containing Host and the result of the command for this Host
""" """
if extra_env is None:
extra_env = {}
return self._run( return self._run(
cmd, cmd,
stdout=stdout, stdout=stdout,
@@ -699,7 +719,7 @@ class HostGroup:
cmd: str | list[str], cmd: str | list[str],
stdout: FILE = None, stdout: FILE = None,
stderr: FILE = None, stderr: FILE = None,
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
cwd: None | str | Path = None, cwd: None | str | Path = None,
check: bool = True, check: bool = True,
timeout: float = math.inf, timeout: float = math.inf,
@@ -715,6 +735,8 @@ class HostGroup:
@return a lists of tuples containing Host and the result of the command for this Host @return a lists of tuples containing Host and the result of the command for this Host
""" """
if extra_env is None:
extra_env = {}
return self._run( return self._run(
cmd, cmd,
local=True, local=True,
@@ -761,8 +783,13 @@ class HostGroup:
def parse_deployment_address( def parse_deployment_address(
machine_name: str, host: str, forward_agent: bool = True, meta: dict[str, Any] = {} machine_name: str,
host: str,
forward_agent: bool = True,
meta: dict[str, Any] | None = None,
) -> Host: ) -> Host:
if meta is None:
meta = {}
parts = host.split("@") parts = host.split("@")
user: str | None = None user: str | None = None
if len(parts) > 1: if len(parts) > 1:

View File

@@ -15,9 +15,11 @@ def ssh(
host: str, host: str,
user: str = "root", user: str = "root",
password: str | None = None, password: str | None = None,
ssh_args: list[str] = [], ssh_args: list[str] | None = None,
torify: bool = False, torify: bool = False,
) -> None: ) -> None:
if ssh_args is None:
ssh_args = []
packages = ["nixpkgs#openssh"] packages = ["nixpkgs#openssh"]
if torify: if torify:
packages.append("nixpkgs#tor") packages.append("nixpkgs#tor")

View File

@@ -33,12 +33,12 @@ def list_state_folders(machine: str, service: None | str = None) -> None:
try: try:
proc = run_no_stdout(cmd) proc = run_no_stdout(cmd)
res = proc.stdout.strip() res = proc.stdout.strip()
except ClanCmdError: except ClanCmdError as e:
raise ClanError( raise ClanError(
"Clan might not have meta attributes", "Clan might not have meta attributes",
location=f"show_clan {uri}", location=f"show_clan {uri}",
description="Evaluation failed on clanInternals.meta attribute", description="Evaluation failed on clanInternals.meta attribute",
) ) from e
state = json.loads(res) state = json.loads(res)
if service: if service:

View File

@@ -94,8 +94,10 @@ def qemu_command(
virtiofsd_socket: Path, virtiofsd_socket: Path,
qmp_socket_file: Path, qmp_socket_file: Path,
qga_socket_file: Path, qga_socket_file: Path,
portmap: list[tuple[int, int]] = [], portmap: list[tuple[int, int]] | None = None,
) -> QemuCommand: ) -> QemuCommand:
if portmap is None:
portmap = []
kernel_cmdline = [ kernel_cmdline = [
(Path(nixos_config["toplevel"]) / "kernel-params").read_text(), (Path(nixos_config["toplevel"]) / "kernel-params").read_text(),
f'init={nixos_config["toplevel"]}/init', f'init={nixos_config["toplevel"]}/init',

View File

@@ -39,9 +39,11 @@ def facts_to_nixos_config(facts: dict[str, dict[str, bytes]]) -> dict:
# TODO move this to the Machines class # TODO move this to the Machines class
def build_vm( def build_vm(
machine: Machine, tmpdir: Path, nix_options: list[str] = [] machine: Machine, tmpdir: Path, nix_options: list[str] | None = None
) -> dict[str, str]: ) -> dict[str, str]:
# TODO pass prompt here for the GTK gui # TODO pass prompt here for the GTK gui
if nix_options is None:
nix_options = []
secrets_dir = get_secrets(machine, tmpdir) secrets_dir = get_secrets(machine, tmpdir)
public_facts_module = importlib.import_module(machine.public_facts_module) public_facts_module = importlib.import_module(machine.public_facts_module)
@@ -58,7 +60,7 @@ def build_vm(
vm_data["secrets_dir"] = str(secrets_dir) vm_data["secrets_dir"] = str(secrets_dir)
return vm_data return vm_data
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise ClanError(f"Failed to parse vm config: {e}") raise ClanError(f"Failed to parse vm config: {e}") from e
def get_secrets( def get_secrets(
@@ -108,9 +110,13 @@ def run_vm(
*, *,
cachedir: Path | None = None, cachedir: Path | None = None,
socketdir: Path | None = None, socketdir: Path | None = None,
nix_options: list[str] = [], nix_options: list[str] | None = None,
portmap: list[tuple[int, int]] = [], portmap: list[tuple[int, int]] | None = None,
) -> None: ) -> None:
if portmap is None:
portmap = []
if nix_options is None:
nix_options = []
with ExitStack() as stack: with ExitStack() as stack:
machine = Machine(name=vm.machine_name, flake=vm.flake_url) machine = Machine(name=vm.machine_name, flake=vm.flake_url)
log.debug(f"Creating VM for {machine}") log.debug(f"Creating VM for {machine}")

View File

@@ -156,7 +156,7 @@ def get_subcommands(
parser: argparse.ArgumentParser, parser: argparse.ArgumentParser,
to: list[Category], to: list[Category],
level: int = 0, level: int = 0,
prefix: list[str] = [], prefix: list[str] | None = None,
) -> tuple[list[Option], list[Option], list[Subcommand]]: ) -> tuple[list[Option], list[Option], list[Subcommand]]:
""" """
Generate Markdown documentation for an argparse.ArgumentParser instance including its subcommands. Generate Markdown documentation for an argparse.ArgumentParser instance including its subcommands.
@@ -168,6 +168,8 @@ def get_subcommands(
# Document each argument # Document each argument
# --flake --option --debug, etc. # --flake --option --debug, etc.
if prefix is None:
prefix = []
flag_options: list[Option] = [] flag_options: list[Option] = []
positional_options: list[Option] = [] positional_options: list[Option] = []
subcommands: list[Subcommand] = [] subcommands: list[Subcommand] = []

View File

@@ -67,5 +67,16 @@ ignore_missing_imports = true
[tool.ruff] [tool.ruff]
target-version = "py311" target-version = "py311"
line-length = 88 line-length = 88
lint.select = [ "E", "F", "I", "U", "N", "RUF", "ANN", "A", "TID" ] lint.select = [
"A",
"ANN",
"B",
"E",
"F",
"I",
"N",
"RUF",
"TID",
"U",
]
lint.ignore = ["E501", "E402", "E731", "ANN101", "ANN401", "A003"] lint.ignore = ["E501", "E402", "E731", "ANN101", "ANN401", "A003"]

View File

@@ -17,12 +17,14 @@ class Command:
def run( def run(
self, self,
command: list[str], command: list[str],
extra_env: dict[str, str] = {}, extra_env: dict[str, str] | None = None,
stdin: _FILE = None, stdin: _FILE = None,
stdout: _FILE = None, stdout: _FILE = None,
stderr: _FILE = None, stderr: _FILE = None,
workdir: Path | None = None, workdir: Path | None = None,
) -> subprocess.Popen[str]: ) -> subprocess.Popen[str]:
if extra_env is None:
extra_env = {}
env = os.environ.copy() env = os.environ.copy()
env.update(extra_env) env.update(extra_env)
# We start a new session here so that we can than more reliably kill all childs as well # We start a new session here so that we can than more reliably kill all childs as well

View File

@@ -53,15 +53,10 @@ class FlakeForTest(NamedTuple):
def generate_flake( def generate_flake(
temporary_home: Path, temporary_home: Path,
flake_template: Path, flake_template: Path,
substitutions: dict[str, str] = { substitutions: dict[str, str] | None = None,
"__CHANGE_ME__": "_test_vm_persistence",
"git+https://git.clan.lol/clan/clan-core": "path://" + str(CLAN_CORE),
"https://git.clan.lol/clan/clan-core/archive/main.tar.gz": "path://"
+ str(CLAN_CORE),
},
# define the machines directly including their config # define the machines directly including their config
machine_configs: dict[str, dict] = {}, machine_configs: dict[str, dict] | None = None,
inventory: dict[str, dict] = {}, inventory: dict[str, dict] | None = None,
) -> FlakeForTest: ) -> FlakeForTest:
""" """
Creates a clan flake with the given name. Creates a clan flake with the given name.
@@ -82,6 +77,17 @@ def generate_flake(
""" """
# copy the template to a new temporary location # copy the template to a new temporary location
if inventory is None:
inventory = {}
if machine_configs is None:
machine_configs = {}
if substitutions is None:
substitutions = {
"__CHANGE_ME__": "_test_vm_persistence",
"git+https://git.clan.lol/clan/clan-core": "path://" + str(CLAN_CORE),
"https://git.clan.lol/clan/clan-core/archive/main.tar.gz": "path://"
+ str(CLAN_CORE),
}
flake = temporary_home / "flake" flake = temporary_home / "flake"
shutil.copytree(flake_template, flake) shutil.copytree(flake_template, flake)
sp.run(["chmod", "+w", "-R", str(flake)], check=True) sp.run(["chmod", "+w", "-R", str(flake)], check=True)
@@ -136,15 +142,19 @@ def create_flake(
flake_template: str | Path, flake_template: str | Path,
clan_core_flake: Path | None = None, clan_core_flake: Path | None = None,
# names referring to pre-defined machines from ../machines # names referring to pre-defined machines from ../machines
machines: list[str] = [], machines: list[str] | None = None,
# alternatively specify the machines directly including their config # alternatively specify the machines directly including their config
machine_configs: dict[str, dict] = {}, machine_configs: dict[str, dict] | None = None,
remote: bool = False, remote: bool = False,
) -> Iterator[FlakeForTest]: ) -> Iterator[FlakeForTest]:
""" """
Creates a flake with the given name and machines. Creates a flake with the given name and machines.
The machine names map to the machines in ./test_machines The machine names map to the machines in ./test_machines
""" """
if machine_configs is None:
machine_configs = {}
if machines is None:
machines = []
if isinstance(flake_template, Path): if isinstance(flake_template, Path):
template_path = flake_template template_path = flake_template
else: else:

View File

@@ -11,7 +11,7 @@ from clan_cli.errors import ClanError
def find_dataclasses_in_directory( def find_dataclasses_in_directory(
directory: Path, exclude_paths: list[str] = [] directory: Path, exclude_paths: list[str] | None = None
) -> list[tuple[str, str]]: ) -> list[tuple[str, str]]:
""" """
Find all dataclass classes in all Python files within a nested directory. Find all dataclass classes in all Python files within a nested directory.
@@ -22,6 +22,8 @@ def find_dataclasses_in_directory(
Returns: Returns:
List[Tuple[str, str]]: A list of tuples containing the file path and the dataclass name. List[Tuple[str, str]]: A list of tuples containing the file path and the dataclass name.
""" """
if exclude_paths is None:
exclude_paths = []
dataclass_files = [] dataclass_files = []
excludes = [os.path.join(directory, d) for d in exclude_paths] excludes = [os.path.join(directory, d) for d in exclude_paths]
@@ -144,4 +146,4 @@ Help:
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
""", """,
location=__file__, location=__file__,
) ) from e

View File

@@ -27,7 +27,7 @@ def test_timeout() -> None:
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised TimeoutExpired" raise AssertionError("should have raised TimeoutExpired")
def test_run_function() -> None: def test_run_function() -> None:
@@ -45,7 +45,7 @@ def test_run_exception() -> None:
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised Exception" raise AssertionError("should have raised Exception")
def test_run_function_exception() -> None: def test_run_function_exception() -> None:
@@ -57,7 +57,7 @@ def test_run_function_exception() -> None:
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised Exception" raise AssertionError("should have raised Exception")
def test_run_local_non_shell() -> None: def test_run_local_non_shell() -> None:

View File

@@ -46,7 +46,7 @@ def test_timeout(host_group: HostGroup) -> None:
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised TimeoutExpired" raise AssertionError("should have raised TimeoutExpired")
def test_run_exception(host_group: HostGroup) -> None: def test_run_exception(host_group: HostGroup) -> None:
@@ -58,7 +58,7 @@ def test_run_exception(host_group: HostGroup) -> None:
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised Exception" raise AssertionError("should have raised Exception")
def test_run_function_exception(host_group: HostGroup) -> None: def test_run_function_exception(host_group: HostGroup) -> None:
@@ -70,4 +70,4 @@ def test_run_function_exception(host_group: HostGroup) -> None:
except Exception: except Exception:
pass pass
else: else:
assert False, "should have raised Exception" raise AssertionError("should have raised Exception")