Merge pull request 'try{300,301,400}: fix' (#4984) from checkout-update into main

Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/4984
This commit is contained in:
Mic92
2025-08-26 14:31:57 +00:00
64 changed files with 305 additions and 327 deletions

View File

@@ -148,8 +148,8 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
self.send_header("Content-Type", content_type)
self.end_headers()
self.wfile.write(file_data)
except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
log.error(f"Error reading Swagger file: {e!s}")
except (OSError, json.JSONDecodeError, UnicodeDecodeError):
log.exception("Error reading Swagger file")
self.send_error(500, "Internal Server Error")
def _get_swagger_file_path(self, rel_path: str) -> Path:
@@ -191,13 +191,13 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
return file_data
def do_OPTIONS(self) -> None: # noqa: N802
def do_OPTIONS(self) -> None:
"""Handle CORS preflight requests."""
self.send_response_only(200)
self._send_cors_headers()
self.end_headers()
def do_GET(self) -> None: # noqa: N802
def do_GET(self) -> None:
"""Handle GET requests."""
parsed_url = urlparse(self.path)
path = parsed_url.path
@@ -211,7 +211,7 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
else:
self.send_api_error_response("info", "Not Found", ["http_bridge", "GET"])
def do_POST(self) -> None: # noqa: N802
def do_POST(self) -> None:
"""Handle POST requests."""
parsed_url = urlparse(self.path)
path = parsed_url.path
@@ -264,10 +264,10 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
"""Read and parse the request body. Returns None if there was an error."""
try:
content_length = int(self.headers.get("Content-Length", 0))
if content_length > 0:
body = self.rfile.read(content_length)
return json.loads(body.decode("utf-8"))
return {}
if content_length == 0:
return {}
body = self.rfile.read(content_length)
return json.loads(body.decode("utf-8"))
except json.JSONDecodeError:
self.send_api_error_response(
"post",

View File

@@ -136,14 +136,14 @@ class TestHttpApiServer:
try:
# Test root endpoint
response = urlopen("http://127.0.0.1:8081/") # noqa: S310
response = urlopen("http://127.0.0.1:8081/")
data: dict = json.loads(response.read().decode())
assert data["body"]["status"] == "success"
assert data["body"]["data"]["message"] == "Clan API Server"
assert data["body"]["data"]["version"] == "1.0.0"
# Test methods endpoint
response = urlopen("http://127.0.0.1:8081/api/methods") # noqa: S310
response = urlopen("http://127.0.0.1:8081/api/methods")
data = json.loads(response.read().decode())
assert data["body"]["status"] == "success"
assert "test_method" in data["body"]["data"]["methods"]
@@ -179,7 +179,7 @@ class TestHttpApiServer:
try:
# Test 404 error
res = urlopen("http://127.0.0.1:8081/nonexistent") # noqa: S310
res = urlopen("http://127.0.0.1:8081/nonexistent")
assert res.status == 200
body = json.loads(res.read().decode())["body"]
assert body["status"] == "error"

View File

@@ -539,11 +539,11 @@ def main() -> None:
try:
args.func(args)
except ClanError as e:
except ClanError:
if debug:
log.exception("Exited with error")
else:
log.error("%s", e)
log.exception("Exited with error")
sys.exit(1)
except KeyboardInterrupt as ex:
log.warning("Interrupted by user", exc_info=ex)

View File

@@ -73,8 +73,7 @@ def complete_machines(
if thread.is_alive():
return iter([])
machines_dict = dict.fromkeys(machines, "machine")
return machines_dict
return dict.fromkeys(machines, "machine")
def complete_services_for_machine(
@@ -118,8 +117,7 @@ def complete_services_for_machine(
if thread.is_alive():
return iter([])
services_dict = dict.fromkeys(services, "service")
return services_dict
return dict.fromkeys(services, "service")
def complete_backup_providers_for_machine(
@@ -162,8 +160,7 @@ def complete_backup_providers_for_machine(
if thread.is_alive():
return iter([])
providers_dict = dict.fromkeys(providers, "provider")
return providers_dict
return dict.fromkeys(providers, "provider")
def complete_state_services_for_machine(
@@ -206,8 +203,7 @@ def complete_state_services_for_machine(
if thread.is_alive():
return iter([])
providers_dict = dict.fromkeys(providers, "service")
return providers_dict
return dict.fromkeys(providers, "service")
def complete_secrets(
@@ -225,8 +221,7 @@ def complete_secrets(
secrets = list_secrets(Flake(flake).path)
secrets_dict = dict.fromkeys(secrets, "secret")
return secrets_dict
return dict.fromkeys(secrets, "secret")
def complete_users(
@@ -244,8 +239,7 @@ def complete_users(
users = list_users(Path(flake))
users_dict = dict.fromkeys(users, "user")
return users_dict
return dict.fromkeys(users, "user")
def complete_groups(
@@ -264,8 +258,7 @@ def complete_groups(
groups_list = list_groups(Path(flake))
groups = [group.name for group in groups_list]
groups_dict = dict.fromkeys(groups, "group")
return groups_dict
return dict.fromkeys(groups, "group")
def complete_templates_disko(
@@ -285,8 +278,7 @@ def complete_templates_disko(
disko_template_list = list_all_templates.builtins.get("disko")
if disko_template_list:
disko_templates = list(disko_template_list)
disko_dict = dict.fromkeys(disko_templates, "disko")
return disko_dict
return dict.fromkeys(disko_templates, "disko")
return []
@@ -307,8 +299,7 @@ def complete_templates_clan(
clan_template_list = list_all_templates.builtins.get("clan")
if clan_template_list:
clan_templates = list(clan_template_list)
clan_dict = dict.fromkeys(clan_templates, "clan")
return clan_dict
return dict.fromkeys(clan_templates, "clan")
return []
@@ -350,8 +341,7 @@ def complete_vars_for_machine(
except (OSError, PermissionError):
pass
vars_dict = dict.fromkeys(vars_list, "var")
return vars_dict
return dict.fromkeys(vars_list, "var")
def complete_target_host(
@@ -392,8 +382,7 @@ def complete_target_host(
if thread.is_alive():
return iter([])
providers_dict = dict.fromkeys(target_hosts, "target_host")
return providers_dict
return dict.fromkeys(target_hosts, "target_host")
def complete_tags(
@@ -462,8 +451,7 @@ def complete_tags(
if any(thread.is_alive() for thread in threads):
return iter([])
providers_dict = dict.fromkeys(tags, "tag")
return providers_dict
return dict.fromkeys(tags, "tag")
def add_dynamic_completer(

View File

@@ -124,10 +124,8 @@ class SecretStore(SecretStoreBase):
os.umask(0o077)
for service in self.machine.facts_data:
for secret in self.machine.facts_data[service]["secret"]:
if isinstance(secret, dict):
secret_name = secret["name"]
else:
# TODO: drop old format soon
secret_name = secret
secret_name = (
secret["name"] if isinstance(secret, dict) else secret
) # TODO: drop old format soon
(output_dir / secret_name).write_bytes(self.get(service, secret_name))
(output_dir / ".pass_info").write_bytes(self.generate_hash())

View File

@@ -14,6 +14,9 @@ from clan_cli.completions import add_dynamic_completer, complete_machines
log = logging.getLogger(__name__)
# Constants for disk validation
EXPECTED_DISK_VALUES = 2
@dataclass
class FlashOptions:
@@ -44,7 +47,7 @@ class AppendDiskAction(argparse.Action):
if not (
isinstance(values, Sequence)
and not isinstance(values, str)
and len(values) == 2
and len(values) == EXPECTED_DISK_VALUES
):
msg = "Two values must be provided for a 'disk'"
raise ValueError(msg)

View File

@@ -3,15 +3,18 @@ import re
VALID_HOSTNAME = re.compile(r"^[a-z0-9]([-a-z0-9]*[a-z0-9])?$", re.IGNORECASE)
# Maximum hostname/machine name length as per RFC specifications
MAX_HOSTNAME_LENGTH = 63
def validate_hostname(hostname: str) -> bool:
if len(hostname) > 63:
if len(hostname) > MAX_HOSTNAME_LENGTH:
return False
return VALID_HOSTNAME.match(hostname) is not None
def machine_name_type(arg_value: str) -> str:
if len(arg_value) > 63:
if len(arg_value) > MAX_HOSTNAME_LENGTH:
msg = "Machine name must be less than 63 characters long"
raise argparse.ArgumentTypeError(msg)
if not VALID_HOSTNAME.match(arg_value):

View File

@@ -10,6 +10,10 @@ from typing import Any
# Ensure you have a logger set up for logging exceptions
log = logging.getLogger(__name__)
# Constants for path trimming and profiler configuration
MAX_PATH_LEVELS = 4
explanation = """
cProfile Output Columns Explanation:
@@ -86,8 +90,8 @@ class ProfilerStore:
def trim_path_to_three_levels(path: str) -> str:
parts = path.split(os.path.sep)
if len(parts) > 4:
return os.path.sep.join(parts[-4:])
if len(parts) > MAX_PATH_LEVELS:
return os.path.sep.join(parts[-MAX_PATH_LEVELS:])
return path

View File

@@ -31,6 +31,9 @@ from .types import VALID_SECRET_NAME, secret_name_type
log = logging.getLogger(__name__)
# Minimum number of keys required to keep a secret group
MIN_KEYS_FOR_GROUP_REMOVAL = 2
def list_generators_secrets(generators_path: Path) -> list[Path]:
paths: list[Path] = []
@@ -328,7 +331,7 @@ def disallow_member(
keys = collect_keys_for_path(group_folder.parent)
if len(keys) < 2:
if len(keys) < MIN_KEYS_FOR_GROUP_REMOVAL:
msg = f"Cannot remove {name} from {group_folder.parent.name}. No keys left. Use 'clan secrets remove {name}' to remove the secret."
raise ClanError(msg)
target.unlink()

View File

@@ -10,6 +10,9 @@ from .sops import get_public_age_keys
VALID_SECRET_NAME = re.compile(r"^[a-zA-Z0-9._-]+$")
VALID_USER_NAME = re.compile(r"^[a-z_]([a-z0-9_-]{0,31})?$")
# Maximum length for user and group names
MAX_USER_GROUP_NAME_LENGTH = 32
def secret_name_type(arg_value: str) -> str:
if not VALID_SECRET_NAME.match(arg_value):
@@ -45,7 +48,7 @@ def public_or_private_age_key_type(arg_value: str) -> str:
def group_or_user_name_type(what: str) -> Callable[[str], str]:
def name_type(arg_value: str) -> str:
if len(arg_value) > 32:
if len(arg_value) > MAX_USER_GROUP_NAME_LENGTH:
msg = f"{what.capitalize()} name must be less than 32 characters long"
raise argparse.ArgumentTypeError(msg)
if not VALID_USER_NAME.match(arg_value):

View File

@@ -184,11 +184,10 @@ class ClanFlake:
self.clan_modules: list[str] = []
self.temporary_home = temporary_home
self.path = temporary_home / "flake"
if not suppress_tmp_home_warning:
if "/tmp" not in str(os.environ.get("HOME")): # noqa: S108 - Checking if HOME is in temp directory
log.warning(
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}",
)
if not suppress_tmp_home_warning and "/tmp" not in str(os.environ.get("HOME")): # noqa: S108 - Checking if HOME is in temp directory
log.warning(
f"!! $HOME does not point to a temp directory!! HOME={os.environ['HOME']}",
)
def copy(
self,

View File

@@ -10,7 +10,7 @@ from clan_lib.ssh.remote import Remote
@pytest.fixture
def hosts(sshd: Sshd) -> list[Remote]:
login = pwd.getpwuid(os.getuid()).pw_name
group = [
return [
Remote(
address="127.0.0.1",
port=sshd.port,
@@ -20,5 +20,3 @@ def hosts(sshd: Sshd) -> list[Remote]:
command_prefix="local_test",
),
]
return group

View File

@@ -1,4 +1,3 @@
# ruff: noqa: SLF001
import argparse
import pytest

View File

@@ -14,6 +14,12 @@ log = logging.getLogger(__name__)
# This is for simulating user input in tests.
MOCK_PROMPT_RESPONSE: None = None
# ASCII control character constants
CTRL_D_ASCII = 4 # EOF character
CTRL_C_ASCII = 3 # Interrupt character
DEL_ASCII = 127 # Delete character
BACKSPACE_ASCII = 8 # Backspace character
class PromptType(enum.Enum):
LINE = "line"
@@ -80,14 +86,14 @@ def get_multiline_hidden_input() -> str:
char = sys.stdin.read(1)
# Check for Ctrl-D (ASCII value 4 or EOF)
if not char or ord(char) == 4:
if not char or ord(char) == CTRL_D_ASCII:
# Add last line if not empty
if current_line:
lines.append("".join(current_line))
break
# Check for Ctrl-C (KeyboardInterrupt)
if ord(char) == 3:
if ord(char) == CTRL_C_ASCII:
raise KeyboardInterrupt
# Handle Enter key
@@ -98,7 +104,7 @@ def get_multiline_hidden_input() -> str:
sys.stdout.write("\r\n")
sys.stdout.flush()
# Handle backspace
elif ord(char) == 127 or ord(char) == 8:
elif ord(char) == DEL_ASCII or ord(char) == BACKSPACE_ASCII:
if current_line:
current_line.pop()
# Regular character

View File

@@ -164,13 +164,12 @@ class SecretStore(StoreBase):
msg = f"file {file_name} was not found"
raise ClanError(msg)
if outdated:
msg = (
return (
"The local state of some secret vars is inconsistent and needs to be updated.\n"
f"Run 'clan vars fix {machine}' to apply the necessary changes."
"Problems to fix:\n"
"\n".join(o[2] for o in outdated if o[2])
)
return msg
return None
def _set(

View File

@@ -24,10 +24,7 @@ def set_var(machine: str | Machine, var: str | Var, value: bytes, flake: Flake)
_machine = Machine(name=machine, flake=flake)
else:
_machine = machine
if isinstance(var, str):
_var = get_machine_var(_machine, var)
else:
_var = var
_var = get_machine_var(_machine, var) if isinstance(var, str) else var
paths = _var.set(value)
if paths:
commit_files(

View File

@@ -93,12 +93,7 @@ def get_machine_options() -> str:
system = platform.system().lower()
# Determine accelerator based on OS
if system == "darwin":
# macOS uses Hypervisor.framework
accel = "hvf"
else:
# Linux and others use KVM
accel = "kvm"
accel = "hvf" if system == "darwin" else "kvm"
if arch in ("x86_64", "amd64", "i386", "i686"):
# For x86_64, use q35 for modern PCIe support

View File

@@ -279,8 +279,7 @@ API.register(get_system_file)
param = sig.parameters.get(arg_name)
if param:
param_class = param.annotation
return param_class
return param.annotation
return None

View File

@@ -7,6 +7,13 @@ from clan_lib.nix import nix_shell
from . import API
# Avahi output parsing constants
MIN_NEW_SERVICE_PARTS = (
6 # Minimum parts for new service discovery (+;interface;protocol;name;type;domain)
)
MIN_RESOLVED_SERVICE_PARTS = 9 # Minimum parts for resolved service (=;interface;protocol;name;type;domain;host;ip;port)
TXT_RECORD_INDEX = 9 # Index where TXT record appears in resolved service output
@dataclass
class Host:
@@ -40,7 +47,7 @@ def parse_avahi_output(output: str) -> DNSInfo:
parts = line.split(";")
# New service discovered
# print(parts)
if parts[0] == "+" and len(parts) >= 6:
if parts[0] == "+" and len(parts) >= MIN_NEW_SERVICE_PARTS:
interface, protocol, name, type_, domain = parts[1:6]
name = decode_escapes(name)
@@ -58,7 +65,7 @@ def parse_avahi_output(output: str) -> DNSInfo:
)
# Resolved more data for already discovered services
elif parts[0] == "=" and len(parts) >= 9:
elif parts[0] == "=" and len(parts) >= MIN_RESOLVED_SERVICE_PARTS:
interface, protocol, name, type_, domain, host, ip, port = parts[1:9]
name = decode_escapes(name)
@@ -67,8 +74,10 @@ def parse_avahi_output(output: str) -> DNSInfo:
dns_info.services[name].host = decode_escapes(host)
dns_info.services[name].ip = ip
dns_info.services[name].port = port
if len(parts) > 9:
dns_info.services[name].txt = decode_escapes(parts[9])
if len(parts) > TXT_RECORD_INDEX:
dns_info.services[name].txt = decode_escapes(
parts[TXT_RECORD_INDEX]
)
else:
dns_info.services[name] = Host(
interface=parts[1],
@@ -79,7 +88,9 @@ def parse_avahi_output(output: str) -> DNSInfo:
host=decode_escapes(parts[6]),
ip=parts[7],
port=parts[8],
txt=decode_escapes(parts[9]) if len(parts) > 9 else None,
txt=decode_escapes(parts[TXT_RECORD_INDEX])
if len(parts) > TXT_RECORD_INDEX
else None,
)
return dns_info
@@ -105,9 +116,7 @@ def list_system_services_mdns() -> DNSInfo:
],
)
proc = run(cmd)
data = parse_avahi_output(proc.stdout)
return data
return parse_avahi_output(proc.stdout)
def mdns_command(_args: argparse.Namespace) -> None:

View File

@@ -22,6 +22,11 @@ from typing import (
from clan_lib.api.serde import dataclass_to_dict
# Annotation constants
TUPLE_KEY_VALUE_PAIR_LENGTH = (
2 # Expected length for tuple annotations like ("key", value)
)
class JSchemaTypeError(Exception):
pass
@@ -37,9 +42,7 @@ def inspect_dataclass_fields(t: type) -> dict[TypeVar, type]:
type_params = origin.__parameters__
# Create a map from type parameters to actual type arguments
type_map = dict(zip(type_params, type_args, strict=False))
return type_map
return dict(zip(type_params, type_args, strict=False))
def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[str, Any]:
@@ -65,7 +68,10 @@ def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[st
if isinstance(annotation, dict):
# Assuming annotation is a dict that can directly apply to the schema
schema.update(annotation)
elif isinstance(annotation, tuple) and len(annotation) == 2:
elif (
isinstance(annotation, tuple)
and len(annotation) == TUPLE_KEY_VALUE_PAIR_LENGTH
):
# Assuming a tuple where first element is a keyword (like 'minLength') and the second is the value
schema[annotation[0]] = annotation[1]
elif isinstance(annotation, str):
@@ -138,9 +144,10 @@ def type_to_dict(
if "null" not in pv["type"]:
required.add(pn)
elif pv.get("oneOf") is not None:
if "null" not in [i.get("type") for i in pv.get("oneOf", [])]:
required.add(pn)
elif pv.get("oneOf") is not None and "null" not in [
i.get("type") for i in pv.get("oneOf", [])
]:
required.add(pn)
required_fields = {
f.name

View File

@@ -71,7 +71,7 @@ def create_clan(opts: CreateOptions) -> None:
try:
nix_metadata(str(opts.src_flake))
except ClanError:
log.error(
log.exception(
f"Found a repository, but it is not a valid flake: {opts.src_flake}",
)
log.warning("Setting src_flake to None")

View File

@@ -386,10 +386,7 @@ def run(
else:
stack.enter_context(terminate_process_group(process))
if isinstance(options.input, bytes):
input_bytes = options.input
else:
input_bytes = None
input_bytes = options.input if isinstance(options.input, bytes) else None
stdout_buf, stderr_buf = handle_io(
process,

View File

@@ -5,6 +5,9 @@ ANSI16_MARKER = 300
ANSI256_MARKER = 301
DEFAULT_MARKER = 302
# RGB color constants
RGB_MAX_VALUE = 255 # Maximum value for RGB color components (0-255)
class RgbColor(Enum):
"""A subset of CSS colors with RGB values that work well in Dark and Light mode."""
@@ -107,7 +110,11 @@ def color_code(spec: tuple[int, int, int], base: ColorType) -> str:
val = _join(base.value + 8, 5, green)
elif red == DEFAULT_MARKER:
val = _join(base.value + 9)
elif 0 <= red <= 255 and 0 <= green <= 255 and 0 <= blue <= 255:
elif (
0 <= red <= RGB_MAX_VALUE
and 0 <= green <= RGB_MAX_VALUE
and 0 <= blue <= RGB_MAX_VALUE
):
val = _join(base.value + 8, 2, red, green, blue)
else:
msg = f"Invalid color specification: {spec}"

View File

@@ -82,9 +82,7 @@ class PrefixFormatter(logging.Formatter):
self.hostnames += [hostname]
index = self.hostnames.index(hostname)
coloroffset = (index + self.hostname_color_offset) % len(colorcodes)
colorcode = colorcodes[coloroffset]
return colorcode
return colorcodes[coloroffset]
def get_callers(start: int = 2, end: int = 2) -> list[str]:

View File

@@ -67,13 +67,11 @@ def indent_command(command_list: list[str]) -> str:
arg = command_list[i]
formatted_command.append(shlex.quote(arg))
if i < len(command_list) - 1:
# Check if the current argument is an option
if arg.startswith("-"):
# Indent after the next argument
formatted_command.append(" ")
i += 1
formatted_command.append(shlex.quote(command_list[i]))
if i < len(command_list) - 1 and arg.startswith("-"):
# Indent after the next argument
formatted_command.append(" ")
i += 1
formatted_command.append(shlex.quote(command_list[i]))
if i < len(command_list) - 1:
# Add line continuation only if it's not the last argument

View File

@@ -23,7 +23,7 @@ def substitute_flake_inputs(clan_dir: Path, clan_core_path: Path) -> None:
assert flake_lock.exists(), "flake.lock should exist after flake update"
@pytest.fixture()
@pytest.fixture
def offline_flake_hook(clan_core: Path) -> Callable[[Path], None]:
def patch(clan_dir: Path) -> None:
substitute_flake_inputs(clan_dir, clan_core)

View File

@@ -35,7 +35,7 @@ def offline_template(tmp_path_factory: Any, offline_session_flake_hook: Any) ->
return dst_dir
@pytest.fixture()
@pytest.fixture
def patch_clan_template(monkeypatch: Any, offline_template: Path) -> None:
@contextmanager
def fake_clan_template(
@@ -51,7 +51,7 @@ def patch_clan_template(monkeypatch: Any, offline_template: Path) -> None:
monkeypatch.setattr("clan_lib.clan.create.clan_template", fake_clan_template)
@pytest.fixture()
@pytest.fixture
def clan_flake(
tmp_path: Path,
patch_clan_template: Any, # noqa: ARG001

View File

@@ -238,10 +238,7 @@ def parse_selector(selector: str) -> list[Selector]:
for i in range(len(selector)):
c = selector[i]
if stack == []:
mode = "start"
else:
mode = stack[-1]
mode = "start" if stack == [] else stack[-1]
if mode == "end":
if c == ".":
@@ -385,10 +382,7 @@ class FlakeCacheEntry:
) -> None:
selector: Selector
# if we have no more selectors, it means we select all keys from now one and futher down
if selectors == []:
selector = Selector(type=SelectorType.ALL)
else:
selector = selectors[0]
selector = Selector(type=SelectorType.ALL) if selectors == [] else selectors[0]
# first we find out if we have all subkeys already
@@ -528,10 +522,7 @@ class FlakeCacheEntry:
if isinstance(self.value, str | float | int | None):
return True
if selectors == []:
selector = Selector(type=SelectorType.ALL)
else:
selector = selectors[0]
selector = Selector(type=SelectorType.ALL) if selectors == [] else selectors[0]
# we just fetch all subkeys, so we need to check of we inserted all keys at this level before
if selector.type == SelectorType.ALL:
@@ -539,10 +530,9 @@ class FlakeCacheEntry:
msg = f"Expected dict for ALL selector caching, got {type(self.value)}"
raise ClanError(msg)
if self.fetched_all:
result = all(
return all(
self.value[sel].is_cached(selectors[1:]) for sel in self.value
)
return result
return False
if (
selector.type == SelectorType.SET
@@ -582,10 +572,7 @@ class FlakeCacheEntry:
def select(self, selectors: list[Selector]) -> Any:
selector: Selector
if selectors == []:
selector = Selector(type=SelectorType.ALL)
else:
selector = selectors[0]
selector = Selector(type=SelectorType.ALL) if selectors == [] else selectors[0]
# mirror nix behavior where we return outPath if no further selector is specified
if selectors == [] and isinstance(self.value, dict) and "outPath" in self.value:
@@ -677,14 +664,12 @@ class FlakeCacheEntry:
result_dict: dict[str, Any] = {}
for key in keys_to_select:
value = self.value[key].select(selectors[1:])
if self.value[key].exists:
# Skip empty dicts when the original value is None
if not (
isinstance(value, dict)
and len(value) == 0
and self.value[key].value is None
):
result_dict[key] = value
if self.value[key].exists and not (
isinstance(value, dict)
and len(value) == 0
and self.value[key].value is None
):
result_dict[key] = value
return result_dict
# return a KeyError if we cannot fetch the key
@@ -738,13 +723,12 @@ class FlakeCacheEntry:
exists = json_data.get("exists", True)
fetched_all = json_data.get("fetched_all", False)
entry = FlakeCacheEntry(
return FlakeCacheEntry(
value=value,
is_list=is_list,
exists=exists,
fetched_all=fetched_all,
)
return entry
def __repr__(self) -> str:
if isinstance(self.value, dict):
@@ -760,10 +744,7 @@ class FlakeCache:
self.cache: FlakeCacheEntry = FlakeCacheEntry()
def insert(self, data: dict[str, Any], selector_str: str) -> None:
if selector_str:
selectors = parse_selector(selector_str)
else:
selectors = []
selectors = parse_selector(selector_str) if selector_str else []
self.cache.insert(data, selectors)
@@ -1104,8 +1085,7 @@ class Flake:
else:
log.debug(f"$ clan select {shlex.quote(selector)}")
value = self._cache.select(selector)
return value
return self._cache.select(selector)
def select_machine(self, machine_name: str, selector: str) -> Any:
"""Select a nix attribute for a specific machine.

View File

@@ -45,15 +45,15 @@ def test_cache_persistance(flake: ClanFlake) -> None:
flake2 = Flake(str(flake.path))
flake1.invalidate_cache()
flake2.invalidate_cache()
assert isinstance(flake1._cache, FlakeCache) # noqa: SLF001
assert isinstance(flake2._cache, FlakeCache) # noqa: SLF001
assert not flake1._cache.is_cached( # noqa: SLF001
assert isinstance(flake1._cache, FlakeCache)
assert isinstance(flake2._cache, FlakeCache)
assert not flake1._cache.is_cached(
"nixosConfigurations.*.config.networking.hostName",
)
flake1.select("nixosConfigurations.*.config.networking.hostName")
flake1.select("nixosConfigurations.*.config.networking.{hostName,hostId}")
flake2.invalidate_cache()
assert flake2._cache.is_cached( # noqa: SLF001
assert flake2._cache.is_cached(
"nixosConfigurations.*.config.networking.{hostName,hostId}",
)
@@ -312,10 +312,10 @@ def test_cache_gc(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
my_flake.select("testfile")
else:
my_flake.select("testfile")
assert my_flake._cache is not None # noqa: SLF001
assert my_flake._cache.is_cached("testfile") # noqa: SLF001
assert my_flake._cache is not None
assert my_flake._cache.is_cached("testfile")
subprocess.run(["nix-collect-garbage"], check=True)
assert not my_flake._cache.is_cached("testfile") # noqa: SLF001
assert not my_flake._cache.is_cached("testfile")
def test_store_path_with_line_numbers_not_wrapped() -> None:

View File

@@ -207,8 +207,8 @@ def test_conditional_all_selector(flake: ClanFlake) -> None:
flake2 = Flake(str(flake.path))
flake1.invalidate_cache()
flake2.invalidate_cache()
assert isinstance(flake1._cache, FlakeCache) # noqa: SLF001
assert isinstance(flake2._cache, FlakeCache) # noqa: SLF001
assert isinstance(flake1._cache, FlakeCache)
assert isinstance(flake2._cache, FlakeCache)
log.info("First select")
res1 = flake1.select("inputs.*.{?clan,?missing}.templates.*.*.description")

View File

@@ -8,6 +8,10 @@ from pathlib import Path
log = logging.getLogger(__name__)
# Constants for log parsing
EXPECTED_FILENAME_PARTS = 2 # date_second_str, file_op_key
MIN_PATH_PARTS_FOR_LOGGING = 3 # date/[groups...]/func/file
@dataclass(frozen=True)
class LogGroupConfig:
@@ -559,7 +563,7 @@ class LogManager:
# Parse filename to get op_key and time
filename_stem = log_file_path.stem
parts = filename_stem.split("_", 1)
if len(parts) == 2:
if len(parts) == EXPECTED_FILENAME_PARTS:
date_second_str, file_op_key = parts
if file_op_key == op_key:
@@ -571,7 +575,9 @@ class LogManager:
relative_to_base = log_file_path.relative_to(base_dir)
path_parts = relative_to_base.parts
if len(path_parts) >= 3: # date/[groups...]/func/file
if (
len(path_parts) >= MIN_PATH_PARTS_FOR_LOGGING
): # date/[groups...]/func/file
date_day = path_parts[0]
func_name = path_parts[
-2

View File

@@ -47,9 +47,7 @@ def configured_log_manager(base_dir: Path) -> LogManager:
clans_config = LogGroupConfig("clans", "Clans")
machines_config = LogGroupConfig("machines", "Machines")
clans_config = clans_config.add_child(machines_config)
log_manager = log_manager.add_root_group_config(clans_config)
return log_manager
return log_manager.add_root_group_config(clans_config)
class TestLogGroupConfig:

View File

@@ -212,7 +212,7 @@ def test_get_machine_writeability(clan_flake: Callable[..., Flake]) -> None:
inventory_store.write(inventory, message="Test writeability")
# Check that the tags were updated
persisted = inventory_store._get_persisted() # noqa: SLF001
persisted = inventory_store._get_persisted()
assert get_value_by_path(persisted, "machines.jon.tags", []) == new_tags
write_info = get_machine_fields_schema(Machine("jon", flake))

View File

@@ -133,8 +133,7 @@ class Machine:
remote = get_machine_host(self.name, self.flake, field="buildHost")
if remote:
data = remote.data
return data
return remote.data
return None

View File

@@ -45,7 +45,6 @@ def check_machine_ssh_login(
["true"],
RunOpts(timeout=opts.timeout, needs_user_terminal=True),
)
return
except ClanCmdTimeoutError as e:
msg = f"SSH login timeout after {opts.timeout}s"
raise ClanError(msg) from e
@@ -54,6 +53,8 @@ def check_machine_ssh_login(
raise ClanError(e.cmd.stderr.strip()) from e
msg = f"SSH login failed: {e}"
raise ClanError(msg) from e
else:
return
@API.register

View File

@@ -65,12 +65,11 @@ class Network:
@cached_property
def module(self) -> "NetworkTechnologyBase":
res = import_with_source(
return import_with_source(
self.module_name,
"NetworkTechnology",
NetworkTechnologyBase, # type: ignore[type-abstract]
)
return res
def is_running(self) -> bool:
return self.module.is_running()

View File

@@ -158,15 +158,15 @@ def read_qr_image(image_path: Path) -> dict[str, Any]:
try:
res = run(cmd)
data = res.stdout.strip()
if not data:
msg = f"No QR code found in image: {image_path}"
raise ClanError(msg)
return json.loads(data)
except json.JSONDecodeError as e:
msg = f"Invalid JSON in QR code: {e}"
raise ClanError(msg) from e
except Exception as e:
except OSError as e:
msg = f"Failed to read QR code from {image_path}: {e}"
raise ClanError(msg) from e
if not data:
msg = f"No QR code found in image: {image_path}"
raise ClanError(msg)
return json.loads(data)

View File

@@ -103,8 +103,7 @@ def nix_eval(flags: list[str]) -> list[str]:
def nix_metadata(flake_url: str | Path) -> dict[str, Any]:
cmd = nix_command(["flake", "metadata", "--json", f"{flake_url}"])
proc = run(cmd)
data = json.loads(proc.stdout)
return data
return json.loads(proc.stdout)
# lazy loads list of allowed and static programs

View File

@@ -1,9 +1,6 @@
# DO NOT EDIT THIS FILE MANUALLY. IT IS GENERATED.
# This file was generated by running `pkgs/clan-cli/clan_lib.inventory/update.sh`
#
# ruff: noqa: N815
# ruff: noqa: N806
# ruff: noqa: F401
# fmt: off
from typing import Any, Literal, NotRequired, TypedDict

View File

@@ -151,9 +151,7 @@ class InventoryStore:
)
else:
filtered = cast("InventorySnapshot", raw_value)
sanitized = sanitize(filtered, self._allowed_path_transforms, [])
return sanitized
return sanitize(filtered, self._allowed_path_transforms, [])
def get_readonly_raw(self) -> Inventory:
attrs = "{" + ",".join(self._keys) + "}"

View File

@@ -1,4 +1,3 @@
# ruff: noqa: SLF001
import json
import os
import shutil

View File

@@ -10,6 +10,9 @@ from clan_lib.errors import ClanError
T = TypeVar("T")
# Priority constants for configuration merging
WRITABLE_PRIORITY_THRESHOLD = 100 # Values below this are not writeable
empty: list[str] = []
@@ -138,8 +141,7 @@ def list_difference(all_items: list, filter_items: list) -> list:
def find_duplicates(string_list: list[str]) -> list[str]:
count = Counter(string_list)
duplicates = [item for item, freq in count.items() if freq > 1]
return duplicates
return [item for item, freq in count.items() if freq > 1]
def find_deleted_paths(
@@ -348,7 +350,7 @@ def determine_writeability(
# If priority is less than 100, all children are not writeable
# If the parent passed "non_writeable" earlier, this makes all children not writeable
if (prio is not None and prio < 100) or non_writeable:
if (prio is not None and prio < WRITABLE_PRIORITY_THRESHOLD) or non_writeable:
results["non_writeable"].add(full_key)
if isinstance(value, dict):
determine_writeability(
@@ -370,7 +372,7 @@ def determine_writeability(
raise ClanError(msg)
is_mergeable = False
if prio == 100:
if prio == WRITABLE_PRIORITY_THRESHOLD:
default = defaults.get(key)
if isinstance(default, dict):
is_mergeable = True
@@ -379,7 +381,7 @@ def determine_writeability(
if key_in_correlated:
is_mergeable = True
is_writeable = prio > 100 or is_mergeable
is_writeable = prio > WRITABLE_PRIORITY_THRESHOLD or is_mergeable
# Append the result
if is_writeable:

View File

@@ -10,7 +10,7 @@ from tempfile import NamedTemporaryFile
def create_sandbox_profile() -> str:
"""Create a sandbox profile that allows access to tmpdir and nix store, based on Nix's sandbox-defaults.sb."""
# Based on Nix's sandbox-defaults.sb implementation with TMPDIR parameter
profile_content = """(version 1)
return """(version 1)
(define TMPDIR (param "_TMPDIR"))
@@ -92,8 +92,6 @@ def create_sandbox_profile() -> str:
(allow process-exec (literal "/usr/bin/env"))
"""
return profile_content
@contextmanager
def sandbox_exec_cmd(generator: str, tmpdir: Path) -> Iterator[list[str]]:

View File

@@ -24,8 +24,7 @@ def list_service_instances(flake: Flake) -> InventoryInstancesType:
"""Returns all currently present service instances including their full configuration"""
inventory_store = InventoryStore(flake)
inventory = inventory_store.read()
instances = inventory.get("instances", {})
return instances
return inventory.get("instances", {})
def collect_tags(machines: InventoryMachinesType) -> set[str]:

View File

@@ -20,9 +20,7 @@ def create_secret_key_nixos_anywhere() -> SSHKeyPair:
"""
private_key_dir = user_nixos_anywhere_dir()
key_pair = generate_ssh_key(private_key_dir)
return key_pair
return generate_ssh_key(private_key_dir)
def generate_ssh_key(root_dir: Path) -> SSHKeyPair:

View File

@@ -22,6 +22,9 @@ from clan_lib.ssh.host_key import HostKeyCheck, hostkey_to_ssh_opts
from clan_lib.ssh.socks_wrapper import SocksWrapper
from clan_lib.ssh.sudo_askpass_proxy import SudoAskpassProxy
# Constants for URL parsing
EXPECTED_URL_PARTS = 2 # Expected parts when splitting on '?' or '='
if TYPE_CHECKING:
from clan_lib.network.check import ConnectionOptions
@@ -483,7 +486,9 @@ def _parse_ssh_uri(
address = address.removeprefix("ssh://")
parts = address.split("?", maxsplit=1)
endpoint, maybe_options = parts if len(parts) == 2 else (parts[0], "")
endpoint, maybe_options = (
parts if len(parts) == EXPECTED_URL_PARTS else (parts[0], "")
)
parts = endpoint.split("@")
match len(parts):
@@ -506,7 +511,7 @@ def _parse_ssh_uri(
if len(o) == 0:
continue
parts = o.split("=", maxsplit=1)
if len(parts) != 2:
if len(parts) != EXPECTED_URL_PARTS:
msg = (
f"Invalid option in host `{address}`: option `{o}` does not have "
f"a value (i.e. expected something like `name=value`)"

View File

@@ -82,14 +82,14 @@ class SudoAskpassProxy:
prompt = stripped_line[len("PASSWORD_REQUESTED:") :].strip()
password = self.handle_password_request(prompt)
if ssh_process.stdin is None:
msg = "SSH process stdin is None"
raise ClanError(msg)
logger.error("SSH process stdin is None")
return
print(password, file=ssh_process.stdin)
ssh_process.stdin.flush()
else:
print(stripped_line)
except (OSError, ClanError) as e:
logger.error(f"Error processing passwords requests output: {e}")
except (OSError, ClanError):
logger.exception("Error processing passwords requests output")
def run(self) -> str:
"""Run the SSH command with password proxying. Returns the askpass script path."""

View File

@@ -6,6 +6,10 @@ from clan_lib.cmd import Log, RunOpts
from clan_lib.errors import ClanError
from clan_lib.ssh.host import Host
# Safety constants for upload paths
MIN_SAFE_DEPTH = 3 # Minimum path depth for safety
MIN_EXCEPTION_DEPTH = 2 # Minimum depth for allowed exceptions
def upload(
host: Host,
@@ -28,11 +32,11 @@ def upload(
depth = len(remote_dest.parts) - 1
# General rule: destination must be at least 3 levels deep for safety.
is_too_shallow = depth < 3
is_too_shallow = depth < MIN_SAFE_DEPTH
# Exceptions: Allow depth 2 if the path starts with /tmp/, /root/, or /etc/.
# This allows destinations like /tmp/mydir or /etc/conf.d, but not /tmp or /etc directly.
is_allowed_exception = depth >= 2 and (
is_allowed_exception = depth >= MIN_EXCEPTION_DEPTH and (
str(remote_dest).startswith("/tmp/") # noqa: S108 - Path validation check
or str(remote_dest).startswith("/root/")
or str(remote_dest).startswith("/etc/")

View File

@@ -51,7 +51,7 @@ def test_list_inventory_tags(clan_flake: Callable[..., Flake]) -> None:
inventory_store.write(inventory, message="Test add tags via API")
# Check that the tags were updated
persisted = inventory_store._get_persisted() # noqa: SLF001
persisted = inventory_store._get_persisted()
assert get_value_by_path(persisted, "machines.jon.tags", []) == new_tags
tags = list_tags(flake)

View File

@@ -89,8 +89,8 @@ def machine_template(
try:
yield dst_machine_dir
except Exception as e:
log.error(f"An error occurred inside the 'machine_template' context: {e}")
except Exception:
log.exception("An error occurred inside the 'machine_template' context")
# Ensure that the directory is removed to avoid half-created machines
# Everything in the with block is considered part of the context
@@ -182,7 +182,7 @@ def clan_template(
try:
post_process(dst_dir)
except Exception as e:
log.error(f"Error during post-processing of clan template: {e}")
log.exception("Error during post-processing of clan template")
log.info(f"Removing left-over directory: {dst_dir}")
shutil.rmtree(dst_dir, ignore_errors=True)
msg = (
@@ -191,8 +191,8 @@ def clan_template(
raise ClanError(msg) from e
try:
yield dst_dir
except Exception as e:
log.error(f"An error occurred inside the 'clan_template' context: {e}")
except Exception:
log.exception("An error occurred inside the 'clan_template' context")
log.info(f"Removing left-over directory: {dst_dir}")
shutil.rmtree(dst_dir, ignore_errors=True)
raise

View File

@@ -6,6 +6,9 @@ from pathlib import Path
from clan_cli.cli import create_parser
# Constants for command line argument validation
EXPECTED_ARGC = 2 # Expected number of command line arguments
hidden_subcommands = ["machine", "b", "f", "m", "se", "st", "va", "net", "network"]
@@ -135,16 +138,14 @@ def indent_next(text: str, indent_size: int = 4) -> str:
"""
indent = " " * indent_size
lines = text.split("\n")
indented_text = lines[0] + ("\n" + indent).join(lines[1:])
return indented_text
return lines[0] + ("\n" + indent).join(lines[1:])
def indent_all(text: str, indent_size: int = 4) -> str:
"""Indent all lines in a string."""
indent = " " * indent_size
lines = text.split("\n")
indented_text = indent + ("\n" + indent).join(lines)
return indented_text
return indent + ("\n" + indent).join(lines)
def get_subcommands(
@@ -382,7 +383,7 @@ def build_command_reference() -> None:
def main() -> None:
if len(sys.argv) != 2:
if len(sys.argv) != EXPECTED_ARGC:
print("Usage: python docs.py <command>")
print("Available commands: reference")
sys.exit(1)

View File

@@ -138,6 +138,4 @@ def spawn(
proc.start()
# Return the process
mp_proc = MPProcess(name=proc_name, proc=proc, out_file=out_file)
return mp_proc
return MPProcess(name=proc_name, proc=proc, out_file=out_file)

View File

@@ -91,8 +91,11 @@ class Core:
core = Core()
# Constants
GTK_VERSION_4 = 4
### from pynicotine.gtkgui.application import GTK_API_VERSION
GTK_API_VERSION = 4
GTK_API_VERSION = GTK_VERSION_4
## from pynicotine.gtkgui.application import GTK_GUI_FOLDER_PATH
GTK_GUI_FOLDER_PATH = "assets"
@@ -899,7 +902,7 @@ class Win32Implementation(BaseImplementation):
def _load_ico_buffer(self, icon_name, icon_size):
ico_buffer = b""
if GTK_API_VERSION >= 4:
if GTK_API_VERSION >= GTK_VERSION_4:
icon = ICON_THEME.lookup_icon(
icon_name,
fallbacks=None,
@@ -1118,14 +1121,17 @@ class Win32Implementation(BaseImplementation):
# Icon pressed
self.activate_callback()
elif l_param in (
self.NIN_BALLOONHIDE,
self.NIN_BALLOONTIMEOUT,
self.NIN_BALLOONUSERCLICK,
elif (
l_param
in (
self.NIN_BALLOONHIDE,
self.NIN_BALLOONTIMEOUT,
self.NIN_BALLOONUSERCLICK,
)
and not config.sections["ui"]["trayicon"]
):
if not config.sections["ui"]["trayicon"]:
# Notification dismissed, but user has disabled tray icon
self._remove_notify_icon()
# Notification dismissed, but user has disabled tray icon
self._remove_notify_icon()
elif msg == self.WM_COMMAND:
# Menu item pressed

View File

@@ -340,11 +340,9 @@ class VMObject(GObject.Object):
# Try to shutdown the VM gracefully using QMP
try:
if self.qmp_wrap is None:
msg = "QMP wrapper is not available"
raise ClanError(msg)
with self.qmp_wrap.qmp_ctx() as qmp:
qmp.command("system_powerdown")
if self.qmp_wrap is not None:
with self.qmp_wrap.qmp_ctx() as qmp:
qmp.command("system_powerdown")
except (ClanError, OSError, ConnectionError) as ex:
log.debug(f"QMP command 'system_powerdown' ignored. Error: {ex}")

View File

@@ -181,8 +181,7 @@ class ClanStore:
if vm_store is None:
return None
vm = vm_store.get(str(machine.name), None)
return vm
return vm_store.get(str(machine.name), None)
def get_running_vms(self) -> list[VMObject]:
return [

View File

@@ -420,9 +420,6 @@ def run_gen(args: argparse.Namespace) -> None:
"""# DO NOT EDIT THIS FILE MANUALLY. IT IS GENERATED.
# This file was generated by running `pkgs/clan-cli/clan_lib.inventory/update.sh`
#
# ruff: noqa: N815
# ruff: noqa: N806
# ruff: noqa: F401
# fmt: off
from typing import Any, Literal, NotRequired, TypedDict\n

View File

@@ -48,8 +48,7 @@ def list_devshells() -> list[str]:
stdout=subprocess.PIPE,
check=True,
)
names = json.loads(flake_show.stdout.decode())
return names
return json.loads(flake_show.stdout.decode())
def print_devshells() -> None:

View File

@@ -8,13 +8,17 @@ from pathlib import Path
ZEROTIER_STATE_DIR = Path("/var/lib/zerotier-one")
# ZeroTier constants
ZEROTIER_NETWORK_ID_LENGTH = 16 # ZeroTier network ID length
HTTP_OK = 200 # HTTP success status code
class ClanError(Exception):
pass
def compute_zerotier_ip(network_id: str, identity: str) -> ipaddress.IPv6Address:
if len(network_id) != 16:
if len(network_id) != ZEROTIER_NETWORK_ID_LENGTH:
msg = f"network_id must be 16 characters long, got {network_id}"
raise ClanError(msg)
try:
@@ -58,9 +62,7 @@ def compute_member_id(ipv6_addr: str) -> str:
node_id_bytes = addr_bytes[10:16]
node_id = int.from_bytes(node_id_bytes, byteorder="big")
member_id = format(node_id, "x").zfill(10)[-10:]
return member_id
return format(node_id, "x").zfill(10)[-10:]
# this is managed by the nixos module
@@ -90,7 +92,7 @@ def allow_member(args: argparse.Namespace) -> None:
{"X-ZT1-AUTH": token},
)
resp = conn.getresponse()
if resp.status != 200:
if resp.status != HTTP_OK:
msg = f"the zerotier daemon returned this error: {resp.status} {resp.reason}"
raise ClanError(msg)
print(resp.status, resp.reason)