Merge pull request 'use pathlib everywhere' (#2023) from type-checking into main

This commit is contained in:
clan-bot
2024-09-02 16:33:46 +00:00
28 changed files with 89 additions and 114 deletions

View File

@@ -2,7 +2,6 @@
import argparse import argparse
import json import json
import logging import logging
import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Any, get_origin from typing import Any, get_origin
@@ -285,8 +284,8 @@ def set_option(
current[option_path_store[-1]] = casted current[option_path_store[-1]] = casted
# check if there is an existing config file # check if there is an existing config file
if os.path.exists(settings_file): if settings_file.exists():
with open(settings_file) as f: with settings_file.open() as f:
current_config = json.load(f) current_config = json.load(f)
else: else:
current_config = {} current_config = {}
@@ -294,7 +293,7 @@ def set_option(
# merge and save the new config file # merge and save the new config file
new_config = merge(current_config, result) new_config = merge(current_config, result)
settings_file.parent.mkdir(parents=True, exist_ok=True) settings_file.parent.mkdir(parents=True, exist_ok=True)
with open(settings_file, "w") as f: with settings_file.open("w") as f:
json.dump(new_config, f, indent=2) json.dump(new_config, f, indent=2)
print(file=f) # add newline at the end of the file to make git happy print(file=f) # add newline at the end of the file to make git happy

View File

@@ -82,7 +82,7 @@ def config_for_machine(flake_dir: Path, machine_name: str) -> dict:
settings_path = machine_settings_file(flake_dir, machine_name) settings_path = machine_settings_file(flake_dir, machine_name)
if not settings_path.exists(): if not settings_path.exists():
return {} return {}
with open(settings_path) as f: with settings_path.open() as f:
return json.load(f) return json.load(f)
@@ -102,7 +102,7 @@ def set_config_for_machine(flake_dir: Path, machine_name: str, config: dict) ->
# write the config to a json file located at {flake}/machines/{machine_name}/settings.json # write the config to a json file located at {flake}/machines/{machine_name}/settings.json
settings_path = machine_settings_file(flake_dir, machine_name) settings_path = machine_settings_file(flake_dir, machine_name)
settings_path.parent.mkdir(parents=True, exist_ok=True) settings_path.parent.mkdir(parents=True, exist_ok=True)
with open(settings_path, "w") as f: with settings_path.open("w") as f:
json.dump(config, f) json.dump(config, f)
if flake_dir is not None: if flake_dir is not None:

View File

@@ -34,7 +34,7 @@ def clan_key_safe(flake_url: str) -> str:
def find_toplevel(top_level_files: list[str]) -> Path | None: def find_toplevel(top_level_files: list[str]) -> Path | None:
"""Returns the path to the toplevel of the clan flake""" """Returns the path to the toplevel of the clan flake"""
for project_file in top_level_files: for project_file in top_level_files:
initial_path = Path(os.getcwd()) initial_path = Path.cwd()
path = Path(initial_path) path = Path(initial_path)
while path.parent != path: while path.parent != path:
if (path / project_file).exists(): if (path / project_file).exists():
@@ -56,30 +56,30 @@ def clan_templates() -> Path:
def user_config_dir() -> Path: def user_config_dir() -> Path:
if sys.platform == "win32": if sys.platform == "win32":
return Path(os.getenv("APPDATA", os.path.expanduser("~\\AppData\\Roaming\\"))) return Path(os.getenv("APPDATA", Path("~\\AppData\\Roaming\\").expanduser()))
if sys.platform == "darwin": if sys.platform == "darwin":
return Path(os.path.expanduser("~/Library/Application Support/")) return Path("~/Library/Application Support/").expanduser()
return Path(os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))) return Path(os.getenv("XDG_CONFIG_HOME", Path("~/.config").expanduser()))
def user_data_dir() -> Path: def user_data_dir() -> Path:
if sys.platform == "win32": if sys.platform == "win32":
return Path( return Path(
os.getenv("LOCALAPPDATA", os.path.expanduser("~\\AppData\\Local\\")) Path(os.getenv("LOCALAPPDATA", Path("~\\AppData\\Local\\").expanduser()))
) )
if sys.platform == "darwin": if sys.platform == "darwin":
return Path(os.path.expanduser("~/Library/Application Support/")) return Path("~/Library/Application Support/").expanduser()
return Path(os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))) return Path(os.getenv("XDG_DATA_HOME", Path("~/.local/share").expanduser()))
def user_cache_dir() -> Path: def user_cache_dir() -> Path:
if sys.platform == "win32": if sys.platform == "win32":
return Path( return Path(
os.getenv("LOCALAPPDATA", os.path.expanduser("~\\AppData\\Local\\")) Path(os.getenv("LOCALAPPDATA", Path("~\\AppData\\Local\\").expanduser()))
) )
if sys.platform == "darwin": if sys.platform == "darwin":
return Path(os.path.expanduser("~/Library/Caches/")) return Path("~/Library/Caches/").expanduser()
return Path(os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))) return Path(os.getenv("XDG_CACHE_HOME", Path("~/.cache").expanduser()))
def user_gcroot_dir() -> Path: def user_gcroot_dir() -> Path:

View File

@@ -1,4 +1,3 @@
import os
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -30,6 +29,6 @@ class SecretStore(SecretStoreBase):
return (self.dir / service / name).exists() return (self.dir / service / name).exists()
def upload(self, output_dir: Path) -> None: def upload(self, output_dir: Path) -> None:
if os.path.exists(output_dir): if output_dir.exists():
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
shutil.copytree(self.dir, output_dir) shutil.copytree(self.dir, output_dir)

View File

@@ -130,7 +130,7 @@ def load_inventory_json(
inventory_file = get_path(flake_dir) inventory_file = get_path(flake_dir)
if inventory_file.exists(): if inventory_file.exists():
with open(inventory_file) as f: with inventory_file.open() as f:
try: try:
res = json.load(f) res = json.load(f)
inventory = from_dict(Inventory, res) inventory = from_dict(Inventory, res)
@@ -153,7 +153,7 @@ def save_inventory(inventory: Inventory, flake_dir: str | Path, message: str) ->
""" """
inventory_file = get_path(flake_dir) inventory_file = get_path(flake_dir)
with open(inventory_file, "w") as f: with inventory_file.open("w") as f:
json.dump(dataclass_to_dict(inventory), f, indent=2) json.dump(dataclass_to_dict(inventory), f, indent=2)
commit_file(inventory_file, Path(flake_dir), commit_message=message) commit_file(inventory_file, Path(flake_dir), commit_message=message)

View File

@@ -10,11 +10,11 @@ from .jsonrpc import ClanJSONEncoder
@contextmanager @contextmanager
def locked_open(filename: str | Path, mode: str = "r") -> Generator: def locked_open(filename: Path, mode: str = "r") -> Generator:
""" """
This is a context manager that provides an advisory write lock on the file specified by `filename` when entering the context, and releases the lock when leaving the context. The lock is acquired using the `fcntl` module's `LOCK_EX` flag, which applies an exclusive write lock to the file. This is a context manager that provides an advisory write lock on the file specified by `filename` when entering the context, and releases the lock when leaving the context. The lock is acquired using the `fcntl` module's `LOCK_EX` flag, which applies an exclusive write lock to the file.
""" """
with open(filename, mode) as fd: with filename.open(mode) as fd:
fcntl.flock(fd, fcntl.LOCK_EX) fcntl.flock(fd, fcntl.LOCK_EX)
yield fd yield fd
fcntl.flock(fd, fcntl.LOCK_UN) fcntl.flock(fd, fcntl.LOCK_UN)

View File

@@ -185,7 +185,7 @@ def generate_machine_hardware_info(
hw_file.replace(backup_file) hw_file.replace(backup_file)
print(f"Backed up existing {hw_file} to {backup_file}") print(f"Backed up existing {hw_file} to {backup_file}")
with open(hw_file, "w") as f: with hw_file.open("w") as f:
f.write(out.stdout) f.write(out.stdout)
print(f"Successfully generated: {hw_file}") print(f"Successfully generated: {hw_file}")

View File

@@ -122,7 +122,7 @@ class Programs:
@classmethod @classmethod
def is_allowed(cls: type["Programs"], program: str) -> bool: def is_allowed(cls: type["Programs"], program: str) -> bool:
if cls.allowed_programs is None: if cls.allowed_programs is None:
with open(Path(__file__).parent / "allowed-programs.json") as f: with (Path(__file__).parent / "allowed-programs.json").open() as f:
cls.allowed_programs = json.load(f) cls.allowed_programs = json.load(f)
return program in cls.allowed_programs return program in cls.allowed_programs

View File

@@ -42,5 +42,5 @@ def remove_object(path: Path, name: str) -> list[Path]:
msg = f"{name} not found in {path}" msg = f"{name} not found in {path}"
raise ClanError(msg) from e raise ClanError(msg) from e
if not os.listdir(path): if not os.listdir(path):
os.rmdir(path) path.rmdir()
return paths_to_commit return paths_to_commit

View File

@@ -122,7 +122,7 @@ def add_member(
if not user_target.is_symlink(): if not user_target.is_symlink():
msg = f"Cannot add user {name}. {user_target} exists but is not a symlink" msg = f"Cannot add user {name}. {user_target} exists but is not a symlink"
raise ClanError(msg) raise ClanError(msg)
os.remove(user_target) user_target.unlink()
user_target.symlink_to(os.path.relpath(source, user_target.parent)) user_target.symlink_to(os.path.relpath(source, user_target.parent))
return update_group_keys(flake_dir, group_folder.parent.name) return update_group_keys(flake_dir, group_folder.parent.name)
@@ -133,16 +133,16 @@ def remove_member(flake_dir: Path, group_folder: Path, name: str) -> None:
msg = f"{name} does not exist in group in {group_folder}: " msg = f"{name} does not exist in group in {group_folder}: "
msg += list_directory(group_folder) msg += list_directory(group_folder)
raise ClanError(msg) raise ClanError(msg)
os.remove(target) target.unlink()
if len(os.listdir(group_folder)) > 0: if len(os.listdir(group_folder)) > 0:
update_group_keys(flake_dir, group_folder.parent.name) update_group_keys(flake_dir, group_folder.parent.name)
if len(os.listdir(group_folder)) == 0: if len(os.listdir(group_folder)) == 0:
os.rmdir(group_folder) group_folder.rmdir()
if len(os.listdir(group_folder.parent)) == 0: if len(os.listdir(group_folder.parent)) == 0:
os.rmdir(group_folder.parent) group_folder.parent.rmdir()
def add_user(flake_dir: Path, group: str, name: str) -> None: def add_user(flake_dir: Path, group: str, name: str) -> None:

View File

@@ -16,7 +16,7 @@ def extract_public_key(filepath: Path) -> str:
Extracts the public key from a given text file. Extracts the public key from a given text file.
""" """
try: try:
with open(filepath) as file: with filepath.open() as file:
for line in file: for line in file:
# Check if the line contains the public key # Check if the line contains the public key
if line.startswith("# public key:"): if line.startswith("# public key:"):

View File

@@ -218,7 +218,7 @@ def allow_member(
if not user_target.is_symlink(): if not user_target.is_symlink():
msg = f"Cannot add user '{name}' to {group_folder.parent.name} secret. {user_target} exists but is not a symlink" msg = f"Cannot add user '{name}' to {group_folder.parent.name} secret. {user_target} exists but is not a symlink"
raise ClanError(msg) raise ClanError(msg)
os.remove(user_target) user_target.unlink()
user_target.symlink_to(os.path.relpath(source, user_target.parent)) user_target.symlink_to(os.path.relpath(source, user_target.parent))
changed = [user_target] changed = [user_target]
@@ -244,13 +244,13 @@ def disallow_member(group_folder: Path, name: str) -> list[Path]:
if len(keys) < 2: if len(keys) < 2:
msg = f"Cannot remove {name} from {group_folder.parent.name}. No keys left. Use 'clan secrets remove {name}' to remove the secret." msg = f"Cannot remove {name} from {group_folder.parent.name}. No keys left. Use 'clan secrets remove {name}' to remove the secret."
raise ClanError(msg) raise ClanError(msg)
os.remove(target) target.unlink()
if len(os.listdir(group_folder)) == 0: if len(os.listdir(group_folder)) == 0:
os.rmdir(group_folder) group_folder.rmdir()
if len(os.listdir(group_folder.parent)) == 0: if len(os.listdir(group_folder.parent)) == 0:
os.rmdir(group_folder.parent) group_folder.parent.rmdir()
return update_keys( return update_keys(
target.parent.parent, sorted(collect_keys_for_path(group_folder.parent)) target.parent.parent, sorted(collect_keys_for_path(group_folder.parent))
@@ -337,7 +337,7 @@ def rename_command(args: argparse.Namespace) -> None:
if new_path.exists(): if new_path.exists():
msg = f"Secret '{args.new_name}' already exists" msg = f"Secret '{args.new_name}' already exists"
raise ClanError(msg) raise ClanError(msg)
os.rename(old_path, new_path) old_path.rename(new_path)
commit_files( commit_files(
[old_path, new_path], [old_path, new_path],
flake_dir, flake_dir,

View File

@@ -172,13 +172,11 @@ def encrypt_file(
with NamedTemporaryFile(delete=False) as f: with NamedTemporaryFile(delete=False) as f:
try: try:
if isinstance(content, str): if isinstance(content, str):
with open(f.name, "w") as fd: Path(f.name).write_text(content)
fd.write(content)
elif isinstance(content, bytes): elif isinstance(content, bytes):
with open(f.name, "wb") as fd: Path(f.name).write_bytes(content)
fd.write(content)
elif isinstance(content, io.IOBase): elif isinstance(content, io.IOBase):
with open(f.name, "w") as fd: with Path(f.name).open("w") as fd:
shutil.copyfileobj(content, fd) shutil.copyfileobj(content, fd)
else: else:
msg = f"Invalid content type: {type(content)}" msg = f"Invalid content type: {type(content)}"
@@ -191,13 +189,13 @@ def encrypt_file(
# atomic copy of the encrypted file # atomic copy of the encrypted file
with NamedTemporaryFile(dir=folder, delete=False) as f2: with NamedTemporaryFile(dir=folder, delete=False) as f2:
shutil.copyfile(f.name, f2.name) shutil.copyfile(f.name, f2.name)
os.rename(f2.name, secret_path) Path(f2.name).rename(secret_path)
meta_path = secret_path.parent / "meta.json" meta_path = secret_path.parent / "meta.json"
with open(meta_path, "w") as f_meta: with meta_path.open("w") as f_meta:
json.dump(meta, f_meta, indent=2) json.dump(meta, f_meta, indent=2)
finally: finally:
with suppress(OSError): with suppress(OSError):
os.remove(f.name) Path(f.name).unlink()
def decrypt_file(secret_path: Path) -> str: def decrypt_file(secret_path: Path) -> str:
@@ -214,7 +212,7 @@ def get_meta(secret_path: Path) -> dict:
meta_path = secret_path.parent / "meta.json" meta_path = secret_path.parent / "meta.json"
if not meta_path.exists(): if not meta_path.exists():
return {} return {}
with open(meta_path) as f: with meta_path.open() as f:
return json.load(f) return json.load(f)
@@ -233,7 +231,7 @@ def write_key(path: Path, publickey: str, overwrite: bool) -> None:
def read_key(path: Path) -> str: def read_key(path: Path) -> str:
with open(path / "key.json") as f: with Path(path / "key.json").open() as f:
try: try:
key = json.load(f) key = json.load(f)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:

View File

@@ -1,5 +1,4 @@
import argparse import argparse
import os
import re import re
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
@@ -20,7 +19,7 @@ def secret_name_type(arg_value: str) -> str:
def public_or_private_age_key_type(arg_value: str) -> str: def public_or_private_age_key_type(arg_value: str) -> str:
if os.path.isfile(arg_value): if Path(arg_value).is_file():
arg_value = Path(arg_value).read_text().strip() arg_value = Path(arg_value).read_text().strip()
if arg_value.startswith("age1"): if arg_value.startswith("age1"):
return arg_value.strip() return arg_value.strip()

View File

@@ -1,4 +1,3 @@
import os
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -36,6 +35,6 @@ class SecretStore(SecretStoreBase):
return (self.dir / service / name).exists() return (self.dir / service / name).exists()
def upload(self, output_dir: Path) -> None: def upload(self, output_dir: Path) -> None:
if os.path.exists(output_dir): if output_dir.exists():
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
shutil.copytree(self.dir, output_dir) shutil.copytree(self.dir, output_dir)

View File

@@ -1,4 +1,3 @@
import os
import random import random
from collections.abc import Generator from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
@@ -43,7 +42,7 @@ def graphics_options(vm: VmConfig) -> GraphicOptions:
#"-chardev", "socket,id=vgpu,path=/tmp/vgpu.sock", #"-chardev", "socket,id=vgpu,path=/tmp/vgpu.sock",
], cid) ], cid)
# fmt: on # fmt: on
if not os.path.exists("/run/opengl-driver"): if not Path("/run/opengl-driver").exists():
display_options = [ display_options = [
"-vga", "-vga",
"none", "none",

View File

@@ -313,7 +313,7 @@ def build_command_reference() -> None:
markdown += "</div>" markdown += "</div>"
markdown += "\n" markdown += "\n"
with open(folder / "index.md", "w") as f: with (folder / "index.md").open("w") as f:
f.write(markdown) f.write(markdown)
# Each top level category is a separate file # Each top level category is a separate file
@@ -374,7 +374,7 @@ def build_command_reference() -> None:
files[folder / f"{filename}.md"] = markdown files[folder / f"{filename}.md"] = markdown
for fname, content in files.items(): for fname, content in files.items():
with open(fname, "w") as f: with fname.open("w") as f:
f.write(content) f.write(content)

View File

@@ -12,7 +12,7 @@ from clan_cli.errors import ClanError
def find_dataclasses_in_directory( def find_dataclasses_in_directory(
directory: Path, exclude_paths: list[str] | None = None directory: Path, exclude_paths: list[str] | None = None
) -> list[tuple[str, str]]: ) -> list[tuple[Path, str]]:
""" """
Find all dataclass classes in all Python files within a nested directory. Find all dataclass classes in all Python files within a nested directory.
@@ -26,42 +26,41 @@ def find_dataclasses_in_directory(
exclude_paths = [] exclude_paths = []
dataclass_files = [] dataclass_files = []
excludes = [os.path.join(directory, d) for d in exclude_paths] excludes = [directory / d for d in exclude_paths]
for root, _, files in os.walk(directory, topdown=False): for root, _, files in os.walk(directory, topdown=False):
for file in files: for file in files:
if not file.endswith(".py"): if not file.endswith(".py"):
continue continue
file_path = os.path.join(root, file) file_path = Path(root) / file
if file_path in excludes: if file_path in excludes:
print(f"Skipping dataclass check for file: {file_path}") print(f"Skipping dataclass check for file: {file_path}")
continue continue
with open(file_path, encoding="utf-8") as f: python_code = file_path.read_text()
try: try:
tree = ast.parse(f.read(), filename=file_path) tree = ast.parse(python_code, filename=file_path)
for node in ast.walk(tree): for node in ast.walk(tree):
if isinstance(node, ast.ClassDef): if isinstance(node, ast.ClassDef):
for deco in node.decorator_list: for deco in node.decorator_list:
if ( if (
isinstance(deco, ast.Name) isinstance(deco, ast.Name) and deco.id == "dataclass"
and deco.id == "dataclass" ) or (
) or ( isinstance(deco, ast.Call)
isinstance(deco, ast.Call) and isinstance(deco.func, ast.Name)
and isinstance(deco.func, ast.Name) and deco.func.id == "dataclass"
and deco.func.id == "dataclass" ):
): dataclass_files.append((file_path, node.name))
dataclass_files.append((file_path, node.name)) except (SyntaxError, UnicodeDecodeError) as e:
except (SyntaxError, UnicodeDecodeError) as e: print(f"Error parsing {file_path}: {e}")
print(f"Error parsing {file_path}: {e}")
return dataclass_files return dataclass_files
def load_dataclass_from_file( def load_dataclass_from_file(
file_path: str, class_name: str, root_dir: str file_path: Path, class_name: str, root_dir: str
) -> type | None: ) -> type | None:
""" """
Load a dataclass from a given file path. Load a dataclass from a given file path.

View File

@@ -33,10 +33,9 @@ def test_create_flake(
# create a hardware-configuration.nix that doesn't throw an eval error # create a hardware-configuration.nix that doesn't throw an eval error
for patch_machine in ["jon", "sara"]: for patch_machine in ["jon", "sara"]:
with open( (
flake_dir / "machines" / f"{patch_machine}/hardware-configuration.nix", "w" flake_dir / "machines" / f"{patch_machine}/hardware-configuration.nix"
) as hw_config_nix: ).write_text("{}")
hw_config_nix.write("{}")
with capture_output as output: with capture_output as output:
cli.run(["machines", "list"]) cli.run(["machines", "list"])

View File

@@ -1,4 +1,3 @@
import os
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -13,7 +12,7 @@ from stdout import CaptureOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from age_keys import KeyPair from age_keys import KeyPair
no_kvm = not os.path.exists("/dev/kvm") no_kvm = not Path("/dev/kvm").exists()
@pytest.mark.impure @pytest.mark.impure

View File

@@ -64,7 +64,7 @@ def _init_proc(
os.setsid() os.setsid()
# Open stdout and stderr # Open stdout and stderr
with open(out_file, "w") as out_fd: with out_file.open("w") as out_fd:
os.dup2(out_fd.fileno(), sys.stdout.fileno()) os.dup2(out_fd.fileno(), sys.stdout.fileno())
os.dup2(out_fd.fileno(), sys.stderr.fileno()) os.dup2(out_fd.fileno(), sys.stderr.fileno())

View File

@@ -21,6 +21,7 @@
import os import os
import sys import sys
import tempfile
from collections.abc import Callable from collections.abc import Callable
from typing import Any, ClassVar from typing import Any, ClassVar
@@ -884,14 +885,8 @@ class Win32Implementation(BaseImplementation):
# No custom icons present, fall back to default icons # No custom icons present, fall back to default icons
ico_buffer = self._load_ico_buffer(icon_name, icon_size) ico_buffer = self._load_ico_buffer(icon_name, icon_size)
try: with tempfile.NamedTemporaryFile(delete=False) as file_handle:
import tempfile file_handle.write(ico_buffer)
file_handle = tempfile.NamedTemporaryFile(delete=False)
with file_handle:
file_handle.write(ico_buffer)
return windll.user32.LoadImageA( return windll.user32.LoadImageA(
0, 0,
encode_path(file_handle.name), encode_path(file_handle.name),
@@ -901,9 +896,6 @@ class Win32Implementation(BaseImplementation):
self.LR_LOADFROMFILE, self.LR_LOADFROMFILE,
) )
finally:
os.remove(file_handle.name)
def _destroy_h_icon(self): def _destroy_h_icon(self):
from ctypes import windll from ctypes import windll

View File

@@ -279,7 +279,7 @@ class VMObject(GObject.Object):
if not self._log_file: if not self._log_file:
try: try:
self._log_file = open(proc.out_file) # noqa: SIM115 self._log_file = Path(proc.out_file).open() # noqa: SIM115
except Exception as ex: except Exception as ex:
log.exception(ex) log.exception(ex)
self._log_file = None self._log_file = None

View File

@@ -273,8 +273,7 @@ class ClanList(Gtk.Box):
logs.set_title(f"""📄<span weight="normal"> {name}</span>""") logs.set_title(f"""📄<span weight="normal"> {name}</span>""")
# initial message. Streaming happens automatically when the file is changed by the build process # initial message. Streaming happens automatically when the file is changed by the build process
with open(vm.build_process.out_file) as f: logs.set_message(vm.build_process.out_file.read_text())
logs.set_message(f.read())
views.set_visible_child_name("logs") views.set_visible_child_name("logs")

View File

@@ -3,6 +3,7 @@ import argparse
import json import json
from collections.abc import Callable from collections.abc import Callable
from functools import partial from functools import partial
from pathlib import Path
from typing import Any from typing import Any
@@ -309,11 +310,11 @@ def generate_dataclass(schema: dict[str, Any], class_name: str = root_class) ->
def run_gen(args: argparse.Namespace) -> None: def run_gen(args: argparse.Namespace) -> None:
print(f"Converting {args.input} to {args.output}") print(f"Converting {args.input} to {args.output}")
dataclass_code = "" dataclass_code = ""
with open(args.input) as f: with args.input.open() as f:
schema = json.load(f) schema = json.load(f)
dataclass_code = generate_dataclass(schema) dataclass_code = generate_dataclass(schema)
with open(args.output, "w") as f: with args.output.open("w") as f:
f.write( f.write(
"""# DON NOT EDIT THIS FILE MANUALLY. IT IS GENERATED. """# DON NOT EDIT THIS FILE MANUALLY. IT IS GENERATED.
# #
@@ -330,8 +331,8 @@ from typing import Any\n\n
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("input", help="Input JSON schema file") parser.add_argument("input", help="Input JSON schema file", type=Path)
parser.add_argument("output", help="Output Python file") parser.add_argument("output", help="Output Python file", type=Path)
parser.set_defaults(func=run_gen) parser.set_defaults(func=run_gen)
args = parser.parse_args() args = parser.parse_args()

View File

@@ -1,22 +1,15 @@
import argparse import argparse
from pathlib import Path
from .state import init_state from .state import init_state
def read_file(file_path: str) -> str:
with open(file_path) as file:
return file.read()
def init_config(args: argparse.Namespace) -> None: def init_config(args: argparse.Namespace) -> None:
key = read_file(args.key) init_state(args.certificate.read_text(), args.key.read_text())
certificate = read_file(args.certificate)
init_state(certificate, key)
print("Finished initializing moonlight state.") print("Finished initializing moonlight state.")
def register_config_initialization_parser(parser: argparse.ArgumentParser) -> None: def register_config_initialization_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--certificate") parser.add_argument("--certificate", type=Path)
parser.add_argument("--key") parser.add_argument("--key", type=Path)
parser.set_defaults(func=init_config) parser.set_defaults(func=init_config)

View File

@@ -1,5 +1,4 @@
import contextlib import contextlib
import os
import random import random
import string import string
from configparser import ConfigParser, DuplicateSectionError, NoOptionError from configparser import ConfigParser, DuplicateSectionError, NoOptionError
@@ -45,12 +44,12 @@ def convert_bytearray_to_string(byte_array: str) -> str:
# this must be created before moonlight is first run # this must be created before moonlight is first run
def init_state(certificate: str, key: str) -> None: def init_state(certificate: str, key: str) -> None:
print("Initializing moonlight state.") print("Initializing moonlight state.")
os.makedirs(moonlight_config_dir(), exist_ok=True) moonlight_config_dir().mkdir(parents=True, exist_ok=True)
print("Initialized moonlight config directory.") print("Initialized moonlight config directory.")
print("Writing moonlight state file.") print("Writing moonlight state file.")
# write the initial bootstrap config file # write the initial bootstrap config file
with open(moonlight_state_file(), "w") as file: with moonlight_state_file().open("w") as file:
config = ConfigParser() config = ConfigParser()
# bytearray ojbects are not supported by ConfigParser, # bytearray ojbects are not supported by ConfigParser,
# so we need to adjust them ourselves # so we need to adjust them ourselves

View File

@@ -1,5 +1,5 @@
[tool.mypy] [tool.mypy]
python_version = "3.10" python_version = "3.11"
pretty = true pretty = true
warn_redundant_casts = true warn_redundant_casts = true
disallow_untyped_calls = true disallow_untyped_calls = true
@@ -28,6 +28,7 @@ lint.select = [
"N", "N",
"PIE", "PIE",
"PT", "PT",
"PTH",
"PYI", "PYI",
"Q", "Q",
"RET", "RET",