Merge pull request 'add SLF lint' (#2019) from type-checking into main

This commit is contained in:
clan-bot
2024-09-02 15:35:31 +00:00
31 changed files with 79 additions and 104 deletions

View File

@@ -50,7 +50,7 @@ class ImplFunc(GObject.Object, Generic[P, B]):
msg = "Method 'async_run' must be implemented" msg = "Method 'async_run' must be implemented"
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _async_run(self, data: Any) -> bool: def internal_async_run(self, data: Any) -> bool:
result = GLib.SOURCE_REMOVE result = GLib.SOURCE_REMOVE
try: try:
result = self.async_run(**data) result = self.async_run(**data)

View File

@@ -118,7 +118,7 @@ class WebExecutor(GObject.Object):
# from_dict really takes Anything and returns an instance of the type/class # from_dict really takes Anything and returns an instance of the type/class
reconciled_arguments[k] = from_dict(arg_class, v) reconciled_arguments[k] = from_dict(arg_class, v)
GLib.idle_add(fn_instance._async_run, reconciled_arguments) GLib.idle_add(fn_instance.internal_async_run, reconciled_arguments)
def on_result(self, source: ImplFunc, data: GResult) -> None: def on_result(self, source: ImplFunc, data: GResult) -> None:
result = dataclass_to_dict(data.result) result = dataclass_to_dict(data.result)

View File

@@ -1,3 +1,4 @@
import contextlib
import os import os
import signal import signal
import subprocess import subprocess
@@ -46,10 +47,8 @@ class Command:
# We just kill all processes as quickly as possible because we don't # We just kill all processes as quickly as possible because we don't
# care about corrupted state and want to make tests fasts. # care about corrupted state and want to make tests fasts.
for p in reversed(self.processes): for p in reversed(self.processes):
try: with contextlib.suppress(OSError):
os.killpg(os.getpgid(p.pid), signal.SIGKILL) os.killpg(os.getpgid(p.pid), signal.SIGKILL)
except OSError:
pass
@pytest.fixture @pytest.fixture

View File

@@ -1,4 +1,5 @@
import argparse import argparse
import contextlib
import logging import logging
import sys import sys
from pathlib import Path from pathlib import Path
@@ -35,10 +36,8 @@ from .ssh import cli as ssh_cli
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
argcomplete: ModuleType | None = None argcomplete: ModuleType | None = None
try: with contextlib.suppress(ImportError):
import argcomplete # type: ignore[no-redef] import argcomplete # type: ignore[no-redef]
except ImportError:
pass
def flake_path(arg: str) -> FlakeId: def flake_path(arg: str) -> FlakeId:
@@ -83,8 +82,8 @@ def add_common_flags(parser: argparse.ArgumentParser) -> None:
def register_common_flags(parser: argparse.ArgumentParser) -> None: def register_common_flags(parser: argparse.ArgumentParser) -> None:
has_subparsers = False has_subparsers = False
for action in parser._actions: for action in parser._actions: # noqa: SLF001
if isinstance(action, argparse._SubParsersAction): if isinstance(action, argparse._SubParsersAction): # noqa: SLF001
for _choice, child_parser in action.choices.items(): for _choice, child_parser in action.choices.items():
has_subparsers = True has_subparsers = True
register_common_flags(child_parser) register_common_flags(child_parser)

View File

@@ -47,11 +47,8 @@ class FlakeId:
""" """
x = urllib.parse.urlparse(str(self.loc)) x = urllib.parse.urlparse(str(self.loc))
if x.scheme == "" or "file" in x.scheme: # See above *file* or empty are the only local schemas
# See above *file* or empty are the only local schemas return x.scheme == "" or "file" in x.scheme
return True
return False
def is_remote(self) -> bool: def is_remote(self) -> bool:
return not self.is_local() return not self.is_local()

View File

@@ -1,4 +1,5 @@
import argparse import argparse
import contextlib
import json import json
import subprocess import subprocess
import threading import threading
@@ -17,10 +18,8 @@ We target a maximum of 1second on our average machine.
argcomplete: ModuleType | None = None argcomplete: ModuleType | None = None
try: with contextlib.suppress(ImportError):
import argcomplete # type: ignore[no-redef] import argcomplete # type: ignore[no-redef]
except ImportError:
pass
# The default completion timeout for commands # The default completion timeout for commands
@@ -211,10 +210,7 @@ def complete_secrets(
from .clan_uri import FlakeId from .clan_uri import FlakeId
from .secrets.secrets import ListSecretsOptions, list_secrets from .secrets.secrets import ListSecretsOptions, list_secrets
if (clan_dir_result := clan_dir(None)) is not None: flake = clan_dir_result if (clan_dir_result := clan_dir(None)) is not None else "."
flake = clan_dir_result
else:
flake = "."
options = ListSecretsOptions( options = ListSecretsOptions(
flake=FlakeId(flake), flake=FlakeId(flake),
@@ -237,10 +233,7 @@ def complete_users(
from .secrets.users import list_users from .secrets.users import list_users
if (clan_dir_result := clan_dir(None)) is not None: flake = clan_dir_result if (clan_dir_result := clan_dir(None)) is not None else "."
flake = clan_dir_result
else:
flake = "."
users = list_users(Path(flake)) users = list_users(Path(flake))
@@ -258,10 +251,7 @@ def complete_groups(
from .secrets.groups import list_groups from .secrets.groups import list_groups
if (clan_dir_result := clan_dir(None)) is not None: flake = clan_dir_result if (clan_dir_result := clan_dir(None)) is not None else "."
flake = clan_dir_result
else:
flake = "."
groups_list = list_groups(Path(flake)) groups_list = list_groups(Path(flake))
groups = [group.name for group in groups_list] groups = [group.name for group in groups_list]

View File

@@ -227,7 +227,7 @@ def find_option(
regex = rf"({first}|<name>)" regex = rf"({first}|<name>)"
for elem in option_path[1:]: for elem in option_path[1:]:
regex += rf"\.({elem}|<name>)" regex += rf"\.({elem}|<name>)"
for opt in options.keys(): for opt in options:
if re.match(regex, opt): if re.match(regex, opt):
return opt, value return opt, value

View File

@@ -16,10 +16,7 @@ def check_secrets(machine: Machine, service: None | str = None) -> bool:
missing_secret_facts = [] missing_secret_facts = []
missing_public_facts = [] missing_public_facts = []
if service: services = [service] if service else list(machine.facts_data.keys())
services = [service]
else:
services = list(machine.facts_data.keys())
for service in services: for service in services:
for secret_fact in machine.facts_data[service]["secret"]: for secret_fact in machine.facts_data[service]["secret"]:
if isinstance(secret_fact, str): if isinstance(secret_fact, str):
@@ -41,9 +38,7 @@ def check_secrets(machine: Machine, service: None | str = None) -> bool:
log.debug(f"missing_secret_facts: {missing_secret_facts}") log.debug(f"missing_secret_facts: {missing_secret_facts}")
log.debug(f"missing_public_facts: {missing_public_facts}") log.debug(f"missing_public_facts: {missing_public_facts}")
if missing_secret_facts or missing_public_facts: return not (missing_secret_facts or missing_public_facts)
return False
return True
def check_command(args: argparse.Namespace) -> None: def check_command(args: argparse.Namespace) -> None:

View File

@@ -9,7 +9,7 @@ def list_history_command(args: argparse.Namespace) -> None:
res: dict[str, list[HistoryEntry]] = {} res: dict[str, list[HistoryEntry]] = {}
for history_entry in list_history(): for history_entry in list_history():
url = str(history_entry.flake.flake_url) url = str(history_entry.flake.flake_url)
if res.get(url, None) is None: if res.get(url) is None:
res[url] = [] res[url] = []
res[url].append(history_entry) res[url].append(history_entry)

View File

@@ -12,6 +12,7 @@ Operate on the returned inventory to make changes
- save_inventory: To persist changes. - save_inventory: To persist changes.
""" """
import contextlib
import json import json
from pathlib import Path from pathlib import Path
@@ -163,10 +164,8 @@ def init_inventory(directory: str, init: Inventory | None = None) -> None:
inventory = None inventory = None
# Try reading the current flake # Try reading the current flake
if init is None: if init is None:
try: with contextlib.suppress(ClanCmdError):
inventory = load_inventory_eval(directory) inventory = load_inventory_eval(directory)
except ClanCmdError:
pass
if init is not None: if init is not None:
inventory = init inventory = init

View File

@@ -27,7 +27,7 @@ def create_machine(flake: FlakeId, machine: Machine) -> None:
full_inventory = load_inventory_eval(flake.path) full_inventory = load_inventory_eval(flake.path)
if machine.name in full_inventory.machines.keys(): if machine.name in full_inventory.machines:
msg = f"Machine with the name {machine.name} already exists" msg = f"Machine with the name {machine.name} already exists"
raise ClanError(msg) raise ClanError(msg)

View File

@@ -166,10 +166,7 @@ def find_reachable_host_from_deploy_json(deploy_json: dict[str, str]) -> str:
host = None host = None
for addr in deploy_json["addrs"]: for addr in deploy_json["addrs"]:
if is_reachable(addr): if is_reachable(addr):
if is_ipv6(addr): host = f"[{addr}]" if is_ipv6(addr) else addr
host = f"[{addr}]"
else:
host = addr
break break
if not host: if not host:
msg = f""" msg = f"""

View File

@@ -80,10 +80,7 @@ class QEMUMonitorProtocol:
self.__sock.listen(1) self.__sock.listen(1)
def __get_sock(self) -> socket.socket: def __get_sock(self) -> socket.socket:
if isinstance(self.__address, tuple): family = socket.AF_INET if isinstance(self.__address, tuple) else socket.AF_UNIX
family = socket.AF_INET
else:
family = socket.AF_UNIX
return socket.socket(family, socket.SOCK_STREAM) return socket.socket(family, socket.SOCK_STREAM)
def __negotiate_capabilities(self) -> dict[str, Any]: def __negotiate_capabilities(self) -> dict[str, Any]:

View File

@@ -4,7 +4,7 @@ import os
import shutil import shutil
import subprocess import subprocess
from collections.abc import Iterator from collections.abc import Iterator
from contextlib import contextmanager from contextlib import contextmanager, suppress
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import IO from typing import IO
@@ -196,10 +196,8 @@ def encrypt_file(
with open(meta_path, "w") as f_meta: with open(meta_path, "w") as f_meta:
json.dump(meta, f_meta, indent=2) json.dump(meta, f_meta, indent=2)
finally: finally:
try: with suppress(OSError):
os.remove(f.name) os.remove(f.name)
except OSError:
pass
def decrypt_file(secret_path: Path) -> str: def decrypt_file(secret_path: Path) -> str:

View File

@@ -476,10 +476,7 @@ class Host:
verbose_ssh: bool = False, verbose_ssh: bool = False,
tty: bool = False, tty: bool = False,
) -> list[str]: ) -> list[str]:
if self.user is not None: ssh_target = f"{self.user}@{self.host}" if self.user is not None else self.host
ssh_target = f"{self.user}@{self.host}"
else:
ssh_target = self.host
ssh_opts = ["-A"] if self.forward_agent else [] ssh_opts = ["-A"] if self.forward_agent else []

View File

@@ -37,9 +37,7 @@ def check_vars(machine: Machine, generator_name: None | str = None) -> bool:
log.debug(f"missing_secret_vars: {missing_secret_vars}") log.debug(f"missing_secret_vars: {missing_secret_vars}")
log.debug(f"missing_public_vars: {missing_public_vars}") log.debug(f"missing_public_vars: {missing_public_vars}")
if missing_secret_vars or missing_public_vars: return not (missing_secret_vars or missing_public_vars)
return False
return True
def check_command(args: argparse.Namespace) -> None: def check_command(args: argparse.Namespace) -> None:

View File

@@ -174,12 +174,12 @@ def get_subcommands(
positional_options: list[Option] = [] positional_options: list[Option] = []
subcommands: list[Subcommand] = [] subcommands: list[Subcommand] = []
for action in parser._actions: for action in parser._actions: # noqa: SLF001
if isinstance(action, argparse._HelpAction): if isinstance(action, argparse._HelpAction): # noqa: SLF001
# Pseudoaction that holds the help message # Pseudoaction that holds the help message
continue continue
if isinstance(action, argparse._SubParsersAction): if isinstance(action, argparse._SubParsersAction): # noqa: SLF001
continue # Subparsers handled separately continue # Subparsers handled separately
option_strings = ", ".join(action.option_strings) option_strings = ", ".join(action.option_strings)
@@ -204,8 +204,8 @@ def get_subcommands(
) )
) )
for action in parser._actions: for action in parser._actions: # noqa: SLF001
if isinstance(action, argparse._SubParsersAction): if isinstance(action, argparse._SubParsersAction): # noqa: SLF001
subparsers: dict[str, argparse.ArgumentParser] = action.choices subparsers: dict[str, argparse.ArgumentParser] = action.choices
for name, subparser in subparsers.items(): for name, subparser in subparsers.items():
@@ -252,8 +252,8 @@ def collect_commands() -> list[Category]:
result: list[Category] = [] result: list[Category] = []
for action in parser._actions: for action in parser._actions: # noqa: SLF001
if isinstance(action, argparse._SubParsersAction): if isinstance(action, argparse._SubParsersAction): # noqa: SLF001
subparsers: dict[str, argparse.ArgumentParser] = action.choices subparsers: dict[str, argparse.ArgumentParser] = action.choices
for name, subparser in subparsers.items(): for name, subparser in subparsers.items():
if str(subparser.description).startswith("WIP"): if str(subparser.description).startswith("WIP"):

View File

@@ -1,3 +1,4 @@
import contextlib
import os import os
import signal import signal
import subprocess import subprocess
@@ -46,10 +47,8 @@ class Command:
# We just kill all processes as quickly as possible because we don't # We just kill all processes as quickly as possible because we don't
# care about corrupted state and want to make tests fasts. # care about corrupted state and want to make tests fasts.
for p in reversed(self.processes): for p in reversed(self.processes):
try: with contextlib.suppress(OSError):
os.killpg(os.getpgid(p.pid), signal.SIGKILL) os.killpg(os.getpgid(p.pid), signal.SIGKILL)
except OSError:
pass
@pytest.fixture @pytest.fixture

View File

@@ -48,9 +48,7 @@ def find_dataclasses_in_directory(
if ( if (
isinstance(deco, ast.Name) isinstance(deco, ast.Name)
and deco.id == "dataclass" and deco.id == "dataclass"
): ) or (
dataclass_files.append((file_path, node.name))
elif (
isinstance(deco, ast.Call) isinstance(deco, ast.Call)
and isinstance(deco.func, ast.Name) and isinstance(deco.func, ast.Name)
and deco.func.id == "dataclass" and deco.func.id == "dataclass"

View File

@@ -27,10 +27,9 @@ def test_commit_file(git_repo: Path) -> None:
def test_commit_file_outside_git_raises_error(git_repo: Path) -> None: def test_commit_file_outside_git_raises_error(git_repo: Path) -> None:
# create a file outside the git (a temporary file) # create a file outside the git (a temporary file)
with tempfile.NamedTemporaryFile() as tmp: with tempfile.NamedTemporaryFile() as tmp, pytest.raises(ClanError):
# this should not fail but skip the commit # this should not fail but skip the commit
with pytest.raises(ClanError): git.commit_file(Path(tmp.name), git_repo, "test commit")
git.commit_file(Path(tmp.name), git_repo, "test commit")
def test_commit_file_not_existing_raises_error(git_repo: Path) -> None: def test_commit_file_not_existing_raises_error(git_repo: Path) -> None:

View File

@@ -25,7 +25,7 @@ def test_history_add(
history_file = user_history_file() history_file = user_history_file()
assert history_file.exists() assert history_file.exists()
history = [HistoryEntry(**entry) for entry in json.loads(open(history_file).read())] history = [HistoryEntry(**entry) for entry in json.loads(history_file.read_text())]
assert str(history[0].flake.flake_url) == str(test_flake_with_core.path) assert str(history[0].flake.flake_url) == str(test_flake_with_core.path)

View File

@@ -33,8 +33,8 @@ def test_list_modules(test_flake_with_core: FlakeForTest) -> None:
assert len(modules_info.items()) > 1 assert len(modules_info.items()) > 1
# Random test for those two modules # Random test for those two modules
assert "borgbackup" in modules_info.keys() assert "borgbackup" in modules_info
assert "syncthing" in modules_info.keys() assert "syncthing" in modules_info
@pytest.mark.impure @pytest.mark.impure

View File

@@ -492,10 +492,13 @@ def test_secrets(
"user2", "user2",
] ]
) )
with pytest.raises(ClanError), use_key(age_keys[2].privkey, monkeypatch): with (
pytest.raises(ClanError),
use_key(age_keys[2].privkey, monkeypatch),
capture_output as output,
):
# user2 is not in the group anymore # user2 is not in the group anymore
with capture_output as output: cli.run(["secrets", "get", "--flake", str(test_flake.path), "key"])
cli.run(["secrets", "get", "--flake", str(test_flake.path), "key"])
print(output.out) print(output.out)
cli.run( cli.run(

View File

@@ -279,7 +279,7 @@ class VMObject(GObject.Object):
if not self._log_file: if not self._log_file:
try: try:
self._log_file = open(proc.out_file) self._log_file = open(proc.out_file) # noqa: SIM115
except Exception as ex: except Exception as ex:
log.exception(ex) log.exception(ex)
self._log_file = None self._log_file = None

View File

@@ -64,11 +64,11 @@ class JoinList:
cls._instance = cls.__new__(cls) cls._instance = cls.__new__(cls)
cls.list_store = Gio.ListStore.new(JoinValue) cls.list_store = Gio.ListStore.new(JoinValue)
ClanStore.use().register_on_deep_change(cls._instance._rerender_join_list) ClanStore.use().register_on_deep_change(cls._instance.rerender_join_list)
return cls._instance return cls._instance
def _rerender_join_list( def rerender_join_list(
self, source: GKVStore, position: int, removed: int, added: int self, source: GKVStore, position: int, removed: int, added: int
) -> None: ) -> None:
self.list_store.items_changed( self.list_store.items_changed(

View File

@@ -1,3 +1,4 @@
import contextlib
import os import os
import signal import signal
import subprocess import subprocess
@@ -46,10 +47,8 @@ class Command:
# We just kill all processes as quickly as possible because we don't # We just kill all processes as quickly as possible because we don't
# care about corrupted state and want to make tests fasts. # care about corrupted state and want to make tests fasts.
for p in reversed(self.processes): for p in reversed(self.processes):
try: with contextlib.suppress(OSError):
os.killpg(os.getpgid(p.pid), signal.SIGKILL) os.killpg(os.getpgid(p.pid), signal.SIGKILL)
except OSError:
pass
@pytest.fixture @pytest.fixture

View File

@@ -15,10 +15,7 @@ def send_join_request(host: str, port: int, cert: str) -> bool:
response = send_join_request_api(host, port) response = send_join_request_api(host, port)
if response: if response:
return response return response
if send_join_request_native(host, port, cert): return bool(send_join_request_native(host, port, cert))
return True
return False
# This is the preferred join method, but sunshines pin mechanism # This is the preferred join method, but sunshines pin mechanism

View File

@@ -1,3 +1,4 @@
import contextlib
import os import os
import random import random
import string import string
@@ -72,10 +73,8 @@ def write_state(data: ConfigParser) -> bool:
def add_sunshine_host_to_parser( def add_sunshine_host_to_parser(
config: ConfigParser, hostname: str, manual_host: str, certificate: str, uuid: str config: ConfigParser, hostname: str, manual_host: str, certificate: str, uuid: str
) -> bool: ) -> bool:
try: with contextlib.suppress(DuplicateSectionError):
config.add_section("hosts") config.add_section("hosts")
except DuplicateSectionError:
pass
# amount of hosts # amount of hosts
try: try:

View File

@@ -12,7 +12,7 @@ def get_context() -> http.client.ssl.SSLContext:
# certfile="/home/kenji/.config/sunshine/credentials/cacert.pem", # certfile="/home/kenji/.config/sunshine/credentials/cacert.pem",
# keyfile="/home/kenji/.config/sunshine/credentials/cakey.pem", # keyfile="/home/kenji/.config/sunshine/credentials/cakey.pem",
# ) # )
return http.client.ssl._create_unverified_context() return http.client.ssl._create_unverified_context() # noqa: SLF001
def pair(pin: str) -> str: def pair(pin: str) -> str:
@@ -42,7 +42,9 @@ def pair(pin: str) -> str:
def restart() -> None: def restart() -> None:
# Define the connection # Define the connection
conn = http.client.HTTPSConnection( conn = http.client.HTTPSConnection(
"localhost", 47990, context=http.client.ssl._create_unverified_context() "localhost",
47990,
context=http.client.ssl._create_unverified_context(), # noqa: SLF001
) )
user_and_pass = base64.b64encode(b"sunshine:sunshine").decode("ascii") user_and_pass = base64.b64encode(b"sunshine:sunshine").decode("ascii")
headers = { headers = {

View File

@@ -19,16 +19,13 @@ class Config:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance.config = configparser.ConfigParser() cls._instance.config = configparser.ConfigParser()
config = config_location or cls._instance.default_sunshine_config_file() config = config_location or cls._instance.default_sunshine_config_file()
cls._instance._config_location = config cls._instance.config_location = config
with open(config) as f: with open(config) as f:
config_string = f"[{PSEUDO_SECTION}]\n" + f.read() config_string = f"[{PSEUDO_SECTION}]\n" + f.read()
print(config_string) print(config_string)
cls._instance.config.read_string(config_string) cls._instance.config.read_string(config_string)
return cls._instance return cls._instance
def config_location(self) -> str:
return self._config_location
def default_sunshine_config_dir(self) -> str: def default_sunshine_config_dir(self) -> str:
return os.path.join(os.path.expanduser("~"), ".config", "sunshine") return os.path.join(os.path.expanduser("~"), ".config", "sunshine")

View File

@@ -33,9 +33,25 @@ lint.select = [
"RET", "RET",
"RSE", "RSE",
"RUF", "RUF",
"SIM",
"SLF",
"SLOT",
"T10", "T10",
"TID", "TID",
"U", "U",
"YTT", "YTT",
] ]
lint.ignore = ["E501", "E402", "E731", "ANN101", "ANN401", "A003", "RET504"] lint.ignore = [
"A003",
"ANN101",
"ANN401",
"E402",
"E501",
"E731",
"PT001",
"PT023",
"RET504",
"SIM102",
"SIM108",
"SIM112",
]