diff --git a/.gitignore b/.gitignore index 94f2e3b20..e40f14cca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .direnv +democlan result* /pkgs/clan-cli/clan_cli/nixpkgs /pkgs/clan-cli/clan_cli/webui/assets diff --git a/pkgs/clan-cli/.vscode/launch.json b/pkgs/clan-cli/.vscode/launch.json index ab2ef11e6..4e2c20a75 100644 --- a/pkgs/clan-cli/.vscode/launch.json +++ b/pkgs/clan-cli/.vscode/launch.json @@ -12,6 +12,15 @@ "justMyCode": false, "args": [ "--reload", "--no-open", "--log-level", "debug" ], + }, + { + "name": "Clan Cli VMs", + "type": "python", + "request": "launch", + "module": "clan_cli", + "justMyCode": false, + "args": [ "vms" ], + } ] } \ No newline at end of file diff --git a/pkgs/clan-cli/clan_cli/__init__.py b/pkgs/clan-cli/clan_cli/__init__.py index bfbe083e8..7cd2c3a28 100644 --- a/pkgs/clan-cli/clan_cli/__init__.py +++ b/pkgs/clan-cli/clan_cli/__init__.py @@ -1,12 +1,15 @@ import argparse +import logging import sys from types import ModuleType from typing import Optional -from . import config, create, machines, secrets, vms, webui +from . import config, create, custom_logger, machines, secrets, vms, webui from .errors import ClanError from .ssh import cli as ssh_cli +log = logging.getLogger(__name__) + argcomplete: Optional[ModuleType] = None try: import argcomplete # type: ignore[no-redef] @@ -62,14 +65,20 @@ def create_parser(prog: Optional[str] = None) -> argparse.ArgumentParser: def main() -> None: parser = create_parser() args = parser.parse_args() + + if args.debug: + custom_logger.register(logging.DEBUG) + log.debug("Debug logging enabled") + else: + custom_logger.register(logging.INFO) + if not hasattr(args, "func"): + log.error("No argparse function registered") return try: args.func(args) except ClanError as e: - if args.debug: - raise - print(f"{sys.argv[0]}: {e}") + log.exception(e) sys.exit(1) diff --git a/pkgs/clan-cli/clan_cli/__main__.py b/pkgs/clan-cli/clan_cli/__main__.py new file mode 100644 index 000000000..868d99efc --- /dev/null +++ b/pkgs/clan-cli/clan_cli/__main__.py @@ -0,0 +1,4 @@ +from . import main + +if __name__ == "__main__": + main() diff --git a/pkgs/clan-cli/clan_cli/async_cmd.py b/pkgs/clan-cli/clan_cli/async_cmd.py new file mode 100644 index 000000000..e9e2dba53 --- /dev/null +++ b/pkgs/clan-cli/clan_cli/async_cmd.py @@ -0,0 +1,28 @@ +import asyncio +import logging +import shlex + +from .errors import ClanError + +log = logging.getLogger(__name__) + + +async def run(cmd: list[str]) -> bytes: + log.debug(f"$: {shlex.join(cmd)}") + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + + if proc.returncode != 0: + raise ClanError( + f""" +command: {shlex.join(cmd)} +exit code: {proc.returncode} +command output: +{stderr.decode("utf-8")} +""" + ) + return stdout diff --git a/pkgs/clan-cli/clan_cli/machines/list.py b/pkgs/clan-cli/clan_cli/machines/list.py index dc4755f69..ae8b1d3b1 100644 --- a/pkgs/clan-cli/clan_cli/machines/list.py +++ b/pkgs/clan-cli/clan_cli/machines/list.py @@ -1,12 +1,16 @@ import argparse +import logging import os from .folders import machines_folder from .types import validate_hostname +log = logging.getLogger(__name__) + def list_machines() -> list[str]: path = machines_folder() + log.debug(f"Listing machines in {path}") if not path.exists(): return [] objs: list[str] = [] diff --git a/pkgs/clan-cli/clan_cli/secrets/generate.py b/pkgs/clan-cli/clan_cli/secrets/generate.py index 9e47c93cf..8433200f0 100644 --- a/pkgs/clan-cli/clan_cli/secrets/generate.py +++ b/pkgs/clan-cli/clan_cli/secrets/generate.py @@ -1,14 +1,18 @@ import argparse +import logging import os import shlex import subprocess +import sys from pathlib import Path from clan_cli.errors import ClanError -from ..dirs import get_clan_flake_toplevel, module_root +from ..dirs import get_clan_flake_toplevel from ..nix import nix_build, nix_config +log = logging.getLogger(__name__) + def build_generate_script(machine: str, clan_dir: Path) -> str: config = nix_config() @@ -31,7 +35,8 @@ def build_generate_script(machine: str, clan_dir: Path) -> str: def run_generate_secrets(secret_generator_script: str, clan_dir: Path) -> None: env = os.environ.copy() env["CLAN_DIR"] = str(clan_dir) - env["PYTHONPATH"] = str(module_root().parent) # TODO do this in the clanCore module + env["PYTHONPATH"] = ":".join(sys.path) # TODO do this in the clanCore module + print(f"generating secrets... {secret_generator_script}") proc = subprocess.run( [secret_generator_script], @@ -39,6 +44,8 @@ def run_generate_secrets(secret_generator_script: str, clan_dir: Path) -> None: ) if proc.returncode != 0: + log.error("stdout: %s", proc.stdout) + log.error("stderr: %s", proc.stderr) raise ClanError("failed to generate secrets") else: print("successfully generated secrets") diff --git a/pkgs/clan-cli/clan_cli/secrets/upload.py b/pkgs/clan-cli/clan_cli/secrets/upload.py index 44aac77b5..69ff7bcee 100644 --- a/pkgs/clan-cli/clan_cli/secrets/upload.py +++ b/pkgs/clan-cli/clan_cli/secrets/upload.py @@ -1,16 +1,20 @@ import argparse import json +import logging import os import shlex import subprocess +import sys from pathlib import Path from tempfile import TemporaryDirectory -from ..dirs import get_clan_flake_toplevel, module_root +from ..dirs import get_clan_flake_toplevel from ..errors import ClanError from ..nix import nix_build, nix_config, nix_shell from ..ssh import parse_deployment_address +log = logging.getLogger(__name__) + def build_upload_script(machine: str, clan_dir: Path) -> str: config = nix_config() @@ -53,7 +57,7 @@ def run_upload_secrets( ) -> None: env = os.environ.copy() env["CLAN_DIR"] = str(clan_dir) - env["PYTHONPATH"] = str(module_root().parent) # TODO do this in the clanCore module + env["PYTHONPATH"] = ":".join(sys.path) # TODO do this in the clanCore module print(f"uploading secrets... {flake_attr}") with TemporaryDirectory() as tempdir_: tempdir = Path(tempdir_) @@ -67,6 +71,8 @@ def run_upload_secrets( ) if proc.returncode != 0: + log.error("Stdout: %s", proc.stdout) + log.error("Stderr: %s", proc.stderr) raise ClanError("failed to upload secrets") h = parse_deployment_address(flake_attr, target) diff --git a/pkgs/clan-cli/clan_cli/task_manager.py b/pkgs/clan-cli/clan_cli/task_manager.py new file mode 100644 index 000000000..f73ff61c3 --- /dev/null +++ b/pkgs/clan-cli/clan_cli/task_manager.py @@ -0,0 +1,169 @@ +import logging +import os +import queue +import select +import shlex +import subprocess +import threading +from typing import Any, Iterator +from uuid import UUID, uuid4 + + +class CmdState: + def __init__(self, log: logging.Logger) -> None: + self.log: logging.Logger = log + self.p: subprocess.Popen | None = None + self.stdout: list[str] = [] + self.stderr: list[str] = [] + self._output: queue.SimpleQueue = queue.SimpleQueue() + self.returncode: int | None = None + self.done: bool = False + self.running: bool = False + self.cmd_str: str | None = None + self.workdir: str | None = None + + def close_queue(self) -> None: + if self.p is not None: + self.returncode = self.p.returncode + self._output.put(None) + self.running = False + self.done = True + + def run(self, cmd: list[str]) -> None: + self.running = True + try: + self.cmd_str = shlex.join(cmd) + self.workdir = os.getcwd() + self.log.debug(f"Working directory: {self.workdir}") + self.log.debug(f"Running command: {shlex.join(cmd)}") + self.p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + cwd=self.workdir, + ) + + while self.p.poll() is None: + # Check if stderr is ready to be read from + rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0) + if self.p.stderr in rlist: + assert self.p.stderr is not None + line = self.p.stderr.readline() + if line != "": + line = line.strip("\n") + self.stderr.append(line) + self.log.debug("stderr: %s", line) + self._output.put(line + "\n") + + if self.p.stdout in rlist: + assert self.p.stdout is not None + line = self.p.stdout.readline() + if line != "": + line = line.strip("\n") + self.stdout.append(line) + self.log.debug("stdout: %s", line) + self._output.put(line + "\n") + + if self.p.returncode != 0: + raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}") + + self.log.debug("Successfully ran command") + finally: + self.close_queue() + + +class BaseTask(threading.Thread): + def __init__(self, uuid: UUID) -> None: + # calling parent class constructor + threading.Thread.__init__(self) + + # constructor + self.uuid: UUID = uuid + self.log = logging.getLogger(__name__) + self.procs: list[CmdState] = [] + self.failed: bool = False + self.finished: bool = False + self.logs_lock = threading.Lock() + + def run(self) -> None: + try: + self.task_run() + except Exception as e: + for proc in self.procs: + proc.close_queue() + self.failed = True + self.log.exception(e) + finally: + self.finished = True + + def task_run(self) -> None: + raise NotImplementedError + + ## TODO: If two clients are connected to the same task, + def logs_iter(self) -> Iterator[str]: + with self.logs_lock: + for proc in self.procs: + if self.finished: + self.log.debug("log iter: Task is finished") + break + if proc.done: + for line in proc.stderr: + yield line + "\n" + for line in proc.stdout: + yield line + "\n" + continue + while True: + out = proc._output + line = out.get() + if line is None: + break + yield line + + def register_cmds(self, num_cmds: int) -> Iterator[CmdState]: + for i in range(num_cmds): + cmd = CmdState(self.log) + self.procs.append(cmd) + + for cmd in self.procs: + yield cmd + + +# TODO: We need to test concurrency +class TaskPool: + def __init__(self) -> None: + self.lock: threading.RLock = threading.RLock() + self.pool: dict[UUID, BaseTask] = {} + + def __getitem__(self, uuid: UUID) -> BaseTask: + with self.lock: + return self.pool[uuid] + + def __setitem__(self, uuid: UUID, task: BaseTask) -> None: + with self.lock: + if uuid in self.pool: + raise KeyError(f"Task with uuid {uuid} already exists") + if type(uuid) is not UUID: + raise TypeError("uuid must be of type UUID") + self.pool[uuid] = task + + +POOL: TaskPool = TaskPool() + + +def get_task(uuid: UUID) -> BaseTask: + global POOL + return POOL[uuid] + + +def register_task(task: type, *args: Any) -> UUID: + global POOL + if not issubclass(task, BaseTask): + raise TypeError("task must be a subclass of BaseTask") + + uuid = uuid4() + + inst_task = task(uuid, *args) + POOL[uuid] = inst_task + inst_task.start() + return uuid diff --git a/pkgs/clan-cli/clan_cli/vms/create.py b/pkgs/clan-cli/clan_cli/vms/create.py index 93ffa6b58..b3eb51d4e 100644 --- a/pkgs/clan-cli/clan_cli/vms/create.py +++ b/pkgs/clan-cli/clan_cli/vms/create.py @@ -1,101 +1,129 @@ import argparse import json -import subprocess import tempfile from pathlib import Path +from typing import Iterator +from uuid import UUID from ..dirs import get_clan_flake_toplevel from ..nix import nix_build, nix_shell +from ..task_manager import BaseTask, CmdState, get_task, register_task +from .inspect import VmConfig -def get_vm_create_info(machine: str) -> dict: +class BuildVmTask(BaseTask): + def __init__(self, uuid: UUID, vm: VmConfig) -> None: + super().__init__(uuid) + self.vm = vm + + def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict: + clan_dir = self.vm.flake_url + machine = self.vm.flake_attr + cmd = next(cmds) + cmd.run( + nix_build( + [ + # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this + f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create' + ] + ) + ) + vm_json = "".join(cmd.stdout) + self.log.debug(f"VM JSON path: {vm_json}") + with open(vm_json) as f: + return json.load(f) + + def task_run(self) -> None: + cmds = self.register_cmds(4) + + machine = self.vm.flake_attr + self.log.debug(f"Creating VM for {machine}") + + # TODO: We should get this from the vm argument + vm_config = self.get_vm_create_info(cmds) + + with tempfile.TemporaryDirectory() as tmpdir_: + xchg_dir = Path(tmpdir_) / "xchg" + xchg_dir.mkdir() + disk_img = f"{tmpdir_}/disk.img" + + cmd = next(cmds) + cmd.run( + nix_shell( + ["qemu"], + [ + "qemu-img", + "create", + "-f", + "raw", + disk_img, + "1024M", + ], + ) + ) + + cmd = next(cmds) + cmd.run( + nix_shell( + ["e2fsprogs"], + [ + "mkfs.ext4", + "-L", + "nixos", + disk_img, + ], + ) + ) + + cmd = next(cmds) + cmd.run( + nix_shell( + ["qemu"], + [ + # fmt: off + "qemu-kvm", + "-name", machine, + "-m", f'{vm_config["memorySize"]}M', + "-smp", str(vm_config["cores"]), + "-device", "virtio-rng-pci", + "-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0", + "-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store", + "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared", + "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg", + "-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report', + "-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root", + "-device", "virtio-keyboard", + "-usb", + "-device", "usb-tablet,bus=usb-bus.0", + "-kernel", f'{vm_config["toplevel"]}/kernel', + "-initrd", vm_config["initrd"], + "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0', + # fmt: on + ], + ) + ) + + +def create_vm(vm: VmConfig) -> UUID: + return register_task(BuildVmTask, vm) + + +def create_command(args: argparse.Namespace) -> None: clan_dir = get_clan_flake_toplevel().as_posix() + vm = VmConfig( + flake_url=clan_dir, + flake_attr=args.machine, + cores=0, + graphics=False, + memory_size=0, + ) - # config = nix_config() - # system = config["system"] - - vm_json = subprocess.run( - nix_build( - [ - # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this - f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create' - ] - ), - stdout=subprocess.PIPE, - check=True, - text=True, - ).stdout.strip() - with open(vm_json) as f: - return json.load(f) - - -def create(args: argparse.Namespace) -> None: - print(f"Creating VM for {args.machine}") - machine = args.machine - vm_config = get_vm_create_info(machine) - with tempfile.TemporaryDirectory() as tmpdir_: - xchg_dir = Path(tmpdir_) / "xchg" - xchg_dir.mkdir() - disk_img = f"{tmpdir_}/disk.img" - subprocess.run( - nix_shell( - ["qemu"], - [ - "qemu-img", - "create", - "-f", - "raw", - disk_img, - "1024M", - ], - ), - stdout=subprocess.PIPE, - check=True, - text=True, - ) - subprocess.run( - [ - "mkfs.ext4", - "-L", - "nixos", - disk_img, - ], - stdout=subprocess.PIPE, - check=True, - text=True, - ) - - subprocess.run( - nix_shell( - ["qemu"], - [ - # fmt: off - "qemu-kvm", - "-name", machine, - "-m", f'{vm_config["memorySize"]}M', - "-smp", str(vm_config["cores"]), - "-device", "virtio-rng-pci", - "-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0", - "-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store", - "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared", - "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg", - "-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report', - "-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root", - "-device", "virtio-keyboard", - "-usb", - "-device", "usb-tablet,bus=usb-bus.0", - "-kernel", f'{vm_config["toplevel"]}/kernel', - "-initrd", vm_config["initrd"], - "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0', - # fmt: on - ], - ), - stdout=subprocess.PIPE, - check=True, - text=True, - ) + uuid = create_vm(vm) + task = get_task(uuid) + for line in task.logs_iter(): + print(line, end="") def register_create_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument("machine", type=str) - parser.set_defaults(func=create) + parser.set_defaults(func=create_command) diff --git a/pkgs/clan-cli/clan_cli/vms/inspect.py b/pkgs/clan-cli/clan_cli/vms/inspect.py index 67e5fedc8..9b8559a75 100644 --- a/pkgs/clan-cli/clan_cli/vms/inspect.py +++ b/pkgs/clan-cli/clan_cli/vms/inspect.py @@ -1,38 +1,42 @@ import argparse +import asyncio import json -import subprocess +from pydantic import BaseModel + +from ..async_cmd import run from ..dirs import get_clan_flake_toplevel from ..nix import nix_eval -def get_vm_inspect_info(machine: str) -> dict: - clan_dir = get_clan_flake_toplevel().as_posix() +class VmConfig(BaseModel): + flake_url: str + flake_attr: str - # config = nix_config() - # system = config["system"] + cores: int + memory_size: int + graphics: bool - return json.loads( - subprocess.run( - nix_eval( - [ - # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation' # TODO use this - f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.config' - ] - ), - stdout=subprocess.PIPE, - check=True, - text=True, - ).stdout + +async def inspect_vm(flake_url: str, flake_attr: str) -> VmConfig: + cmd = nix_eval( + [ + f"{flake_url}#nixosConfigurations.{json.dumps(flake_attr)}.config.system.clan.vm.config" + ] ) + stdout = await run(cmd) + data = json.loads(stdout) + return VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data) -def inspect(args: argparse.Namespace) -> None: - print(f"Creating VM for {args.machine}") - machine = args.machine - print(get_vm_inspect_info(machine)) +def inspect_command(args: argparse.Namespace) -> None: + clan_dir = get_clan_flake_toplevel().as_posix() + res = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine)) + print("Cores:", res.cores) + print("Memory size:", res.memory_size) + print("Graphics:", res.graphics) def register_inspect_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument("machine", type=str) - parser.set_defaults(func=inspect) + parser.set_defaults(func=inspect_command) diff --git a/pkgs/clan-cli/clan_cli/webui/__init__.py b/pkgs/clan-cli/clan_cli/webui/__init__.py index fc1d8ca55..ca71979ed 100644 --- a/pkgs/clan-cli/clan_cli/webui/__init__.py +++ b/pkgs/clan-cli/clan_cli/webui/__init__.py @@ -45,6 +45,8 @@ def register_parser(parser: argparse.ArgumentParser) -> None: help="Log level", choices=["critical", "error", "warning", "info", "debug", "trace"], ) + + # Set the args.func variable in args if start_server is None: parser.set_defaults(func=fastapi_is_not_installed) else: diff --git a/pkgs/clan-cli/clan_cli/webui/__main__.py b/pkgs/clan-cli/clan_cli/webui/__main__.py index c551d7042..f6bd9ea79 100644 --- a/pkgs/clan-cli/clan_cli/webui/__main__.py +++ b/pkgs/clan-cli/clan_cli/webui/__main__.py @@ -5,6 +5,11 @@ from . import register_parser if __name__ == "__main__": # this is use in our integration test parser = argparse.ArgumentParser() + # call the register_parser function, which adds arguments to the parser register_parser(parser) args = parser.parse_args() + + # call the function that is stored + # in the func attribute of args, and pass args as the argument + # look into register_parser to see how this is done args.func(args) diff --git a/pkgs/clan-cli/clan_cli/webui/app.py b/pkgs/clan-cli/clan_cli/webui/app.py index b3efaa603..d399577e1 100644 --- a/pkgs/clan-cli/clan_cli/webui/app.py +++ b/pkgs/clan-cli/clan_cli/webui/app.py @@ -5,9 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.routing import APIRoute from fastapi.staticfiles import StaticFiles -from .. import custom_logger +from ..errors import ClanError from .assets import asset_path -from .routers import flake, health, machines, root, utils, vms +from .error_handlers import clan_error_handler +from .routers import flake, health, machines, root, vms origins = [ "http://localhost:3000", @@ -33,9 +34,7 @@ def setup_app() -> FastAPI: # Needs to be last in register. Because of wildcard route app.include_router(root.router) - app.add_exception_handler( - utils.NixBuildException, utils.nix_build_exception_handler - ) + app.add_exception_handler(ClanError, clan_error_handler) app.mount("/static", StaticFiles(directory=asset_path()), name="static") @@ -43,15 +42,11 @@ def setup_app() -> FastAPI: if isinstance(route, APIRoute): route.operation_id = route.name # in this case, 'read_items' log.debug(f"Registered route: {route}") + + for i in app.exception_handlers.items(): + log.debug(f"Registered exception handler: {i}") + return app -# TODO: How do I get the log level from the command line in here? -custom_logger.register(logging.DEBUG) app = setup_app() - -for i in app.exception_handlers.items(): - log.info(f"Registered exception handler: {i}") - -log.warning("log warn") -log.debug("log debug") diff --git a/pkgs/clan-cli/clan_cli/webui/assets.py b/pkgs/clan-cli/clan_cli/webui/assets.py index b6a027c4b..4e1de38ec 100644 --- a/pkgs/clan-cli/clan_cli/webui/assets.py +++ b/pkgs/clan-cli/clan_cli/webui/assets.py @@ -1,7 +1,39 @@ import functools +import logging from pathlib import Path +log = logging.getLogger(__name__) + + +def get_hash(string: str) -> str: + """ + This function takes a string like '/nix/store/kkvk20b8zh8aafdnfjp6dnf062x19732-source' + and returns the hash part 'kkvk20b8zh8aafdnfjp6dnf062x19732' after '/nix/store/' and before '-source'. + """ + # Split the string by '/' and get the last element + last_element = string.split("/")[-1] + # Split the last element by '-' and get the first element + hash_part = last_element.split("-")[0] + # Return the hash part + return hash_part + + +def check_divergence(path: Path) -> None: + p = path.resolve() + + log.info("Absolute web asset path: %s", p) + if not p.is_dir(): + raise FileNotFoundError(p) + + # Get the hash part of the path + gh = get_hash(str(p)) + + log.debug(f"Serving webui asset with hash {gh}") + @functools.cache def asset_path() -> Path: - return Path(__file__).parent / "assets" + path = Path(__file__).parent / "assets" + log.debug("Serving assets from: %s", path) + check_divergence(path) + return path diff --git a/pkgs/clan-cli/clan_cli/webui/error_handlers.py b/pkgs/clan-cli/clan_cli/webui/error_handlers.py new file mode 100644 index 000000000..c7f226d0f --- /dev/null +++ b/pkgs/clan-cli/clan_cli/webui/error_handlers.py @@ -0,0 +1,23 @@ +import logging + +from fastapi import Request, status +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse + +from ..errors import ClanError + +log = logging.getLogger(__name__) + + +def clan_error_handler(request: Request, exc: ClanError) -> JSONResponse: + log.error("ClanError: %s", exc) + detail = [ + { + "loc": [], + "msg": str(exc), + } + ] + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content=jsonable_encoder(dict(detail=detail)), + ) diff --git a/pkgs/clan-cli/clan_cli/webui/routers/flake.py b/pkgs/clan-cli/clan_cli/webui/routers/flake.py index c5f15a970..e1361284b 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/flake.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/flake.py @@ -6,15 +6,15 @@ from fastapi import APIRouter, HTTPException from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse +from ...async_cmd import run from ...nix import nix_command, nix_flake_show -from .utils import run_cmd router = APIRouter() async def get_attrs(url: str) -> list[str]: cmd = nix_flake_show(url) - stdout = await run_cmd(cmd) + stdout = await run(cmd) data: dict[str, dict] = {} try: @@ -45,7 +45,7 @@ async def inspect_flake( # Extract the flake from the given URL # We do this by running 'nix flake prefetch {url} --json' cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"]) - stdout = await run_cmd(cmd) + stdout = await run(cmd) data: dict[str, str] = json.loads(stdout) if data.get("storePath") is None: diff --git a/pkgs/clan-cli/clan_cli/webui/routers/root.py b/pkgs/clan-cli/clan_cli/webui/routers/root.py index e8121d07c..b148270c7 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/root.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/root.py @@ -1,3 +1,4 @@ +import logging import os from mimetypes import guess_type from pathlib import Path @@ -8,6 +9,8 @@ from ..assets import asset_path router = APIRouter() +log = logging.getLogger(__name__) + @router.get("/{path_name:path}") async def root(path_name: str) -> Response: @@ -16,6 +19,7 @@ async def root(path_name: str) -> Response: filename = Path(os.path.normpath(asset_path() / path_name)) if not filename.is_relative_to(asset_path()): + log.error("Prevented directory traversal: %s", filename) # prevent directory traversal return Response(status_code=403) @@ -23,6 +27,7 @@ async def root(path_name: str) -> Response: if filename.suffix == "": filename = filename.with_suffix(".html") if not filename.is_file(): + log.error("File not found: %s", filename) return Response(status_code=404) content_type, _ = guess_type(filename) diff --git a/pkgs/clan-cli/clan_cli/webui/routers/utils.py b/pkgs/clan-cli/clan_cli/webui/routers/utils.py deleted file mode 100644 index dff71d245..000000000 --- a/pkgs/clan-cli/clan_cli/webui/routers/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -import asyncio -import logging -import shlex - -from fastapi import HTTPException, Request, status -from fastapi.encoders import jsonable_encoder -from fastapi.responses import JSONResponse - -log = logging.getLogger(__name__) - - -class NixBuildException(HTTPException): - def __init__(self, msg: str, loc: list = ["body", "flake_attr"]): - detail = [ - { - "loc": loc, - "msg": msg, - "type": "value_error", - } - ] - super().__init__( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail - ) - - -def nix_build_exception_handler( - request: Request, exc: NixBuildException -) -> JSONResponse: - log.error("NixBuildException: %s", exc) - return JSONResponse( - status_code=exc.status_code, - content=jsonable_encoder(dict(detail=exc.detail)), - ) - - -async def run_cmd(cmd: list[str]) -> bytes: - log.debug(f"Running command: {shlex.join(cmd)}") - proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await proc.communicate() - - if proc.returncode != 0: - raise NixBuildException( - f""" -command: {shlex.join(cmd)} -exit code: {proc.returncode} -command output: -{stderr.decode("utf-8")} -""" - ) - return stdout diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index 2cd37b2bd..a56e9a17d 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -1,73 +1,27 @@ -import json import logging from typing import Annotated, Iterator from uuid import UUID -from fastapi import APIRouter, BackgroundTasks, Body, status +from fastapi import APIRouter, Body, status from fastapi.exceptions import HTTPException from fastapi.responses import StreamingResponse from clan_cli.webui.routers.flake import get_attrs -from ...nix import nix_build, nix_eval +from ...task_manager import get_task +from ...vms import create, inspect from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse -from ..task_manager import BaseTask, get_task, register_task -from .utils import run_cmd log = logging.getLogger(__name__) router = APIRouter() -def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]: - return nix_eval( - [ - f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.clan.vm.config" - ] - ) - - -def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]: - return nix_build( - [ - f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.build.vm" - ] - ) - - -class BuildVmTask(BaseTask): - def __init__(self, uuid: UUID, vm: VmConfig) -> None: - super().__init__(uuid) - self.vm = vm - - def run(self) -> None: - try: - self.log.debug(f"BuildVM with uuid {self.uuid} started") - cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url) - - proc = self.run_cmd(cmd) - self.log.debug(f"stdout: {proc.stdout}") - - vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm" - self.log.debug(f"vm_path: {vm_path}") - - self.run_cmd([vm_path]) - self.finished = True - except Exception as e: - self.failed = True - self.finished = True - log.exception(e) - - @router.post("/api/vms/inspect") async def inspect_vm( flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()] ) -> VmInspectResponse: - cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url) - stdout = await run_cmd(cmd) - data = json.loads(stdout) - return VmInspectResponse( - config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data) - ) + config = await inspect.inspect_vm(flake_url, flake_attr) + return VmInspectResponse(config=config) @router.get("/api/vms/{uuid}/status") @@ -84,21 +38,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse: def stream_logs() -> Iterator[str]: task = get_task(uuid) - for proc in task.procs: - if proc.done: - log.debug("stream logs and proc is done") - for line in proc.stderr: - yield line + "\n" - for line in proc.stdout: - yield line + "\n" - continue - while True: - out = proc.output - line = out.get() - if line is None: - log.debug("stream logs and line is None") - break - yield line + yield from task.logs_iter() return StreamingResponse( content=stream_logs(), @@ -107,14 +47,12 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse: @router.post("/api/vms/create") -async def create_vm( - vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks -) -> VmCreateResponse: +async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse: flake_attrs = await get_attrs(vm.flake_url) if vm.flake_attr not in flake_attrs: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Provided attribute '{vm.flake_attr}' does not exist.", ) - uuid = register_task(BuildVmTask, vm) + uuid = create.create_vm(vm) return VmCreateResponse(uuid=str(uuid)) diff --git a/pkgs/clan-cli/clan_cli/webui/schemas.py b/pkgs/clan-cli/clan_cli/webui/schemas.py index c87931a04..578125395 100644 --- a/pkgs/clan-cli/clan_cli/webui/schemas.py +++ b/pkgs/clan-cli/clan_cli/webui/schemas.py @@ -3,6 +3,8 @@ from typing import List from pydantic import BaseModel, Field +from ..vms.inspect import VmConfig + class Status(Enum): ONLINE = "online" @@ -35,15 +37,6 @@ class SchemaResponse(BaseModel): schema_: dict = Field(alias="schema") -class VmConfig(BaseModel): - flake_url: str - flake_attr: str - - cores: int - memory_size: int - graphics: bool - - class VmStatusResponse(BaseModel): returncode: list[int | None] running: bool diff --git a/pkgs/clan-cli/clan_cli/webui/server.py b/pkgs/clan-cli/clan_cli/webui/server.py index 8d67d5a45..f780f9b62 100644 --- a/pkgs/clan-cli/clan_cli/webui/server.py +++ b/pkgs/clan-cli/clan_cli/webui/server.py @@ -1,6 +1,11 @@ import argparse import logging +import multiprocessing as mp +import os +import socket import subprocess +import sys +import syslog import time import urllib.request import webbrowser @@ -90,3 +95,98 @@ def start_server(args: argparse.Namespace) -> None: access_log=args.log_level == "debug", headers=headers, ) + + +# Define a function that takes the path of the file socket as input and returns True if it is served, False otherwise +def is_served(file_socket: Path) -> bool: + # Create a Unix stream socket + client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + # Try to connect to the file socket + try: + client.connect(str(file_socket)) + # Connection succeeded, return True + return True + except OSError: + # Connection failed, return False + return False + finally: + # Close the client socket + client.close() + + +def set_out_to_syslog() -> None: # type: ignore + # Define some constants for convenience + log_levels = { + "emerg": syslog.LOG_EMERG, + "alert": syslog.LOG_ALERT, + "crit": syslog.LOG_CRIT, + "err": syslog.LOG_ERR, + "warning": syslog.LOG_WARNING, + "notice": syslog.LOG_NOTICE, + "info": syslog.LOG_INFO, + "debug": syslog.LOG_DEBUG, + } + facility = syslog.LOG_USER # Use user facility for custom applications + + # Open a connection to the system logger + syslog.openlog("clan-cli", 0, facility) # Use "myapp" as the prefix for messages + + # Define a custom write function that sends messages to syslog + def write(message: str) -> int: + # Strip the newline character from the message + message = message.rstrip("\n") + # Check if the message is not empty + if message: + # Send the message to syslog with the appropriate level + if message.startswith("ERROR:"): + # Use error level for messages that start with "ERROR:" + syslog.syslog(log_levels["err"], message) + else: + # Use info level for other messages + syslog.syslog(log_levels["info"], message) + return 0 + + # Assign the custom write function to sys.stdout and sys.stderr + setattr(sys.stdout, "write", write) + setattr(sys.stderr, "write", write) + + # Define a dummy flush function to prevent errors + def flush() -> None: + pass + + # Assign the dummy flush function to sys.stdout and sys.stderr + setattr(sys.stdout, "flush", flush) + setattr(sys.stderr, "flush", flush) + + +def _run_socketfile(socket_file: Path, debug: bool) -> None: + set_out_to_syslog() + run( + "clan_cli.webui.app:app", + uds=str(socket_file), + access_log=debug, + reload=False, + log_level="debug" if debug else "info", + ) + + +@contextmanager +def api_server(debug: bool) -> Iterator[Path]: + runtime_dir = os.getenv("XDG_RUNTIME_DIR") + if runtime_dir is None: + raise RuntimeError("XDG_RUNTIME_DIR not set") + socket_path = Path(runtime_dir) / "clan.sock" + socket_path = socket_path.resolve() + + log.debug("Socketfile lies at %s", socket_path) + + if not is_served(socket_path): + log.debug("Starting api server...") + mp.set_start_method(method="spawn") + proc = mp.Process(target=_run_socketfile, args=(socket_path, debug)) + proc.start() + else: + log.info("Api server is already running on %s", socket_path) + + yield socket_path + proc.terminate() diff --git a/pkgs/clan-cli/clan_cli/webui/task_manager.py b/pkgs/clan-cli/clan_cli/webui/task_manager.py deleted file mode 100644 index 21374cb55..000000000 --- a/pkgs/clan-cli/clan_cli/webui/task_manager.py +++ /dev/null @@ -1,119 +0,0 @@ -import logging -import os -import queue -import select -import shlex -import subprocess -import threading -from typing import Any -from uuid import UUID, uuid4 - - -class CmdState: - def __init__(self, proc: subprocess.Popen) -> None: - global LOOP - self.proc: subprocess.Popen = proc - self.stdout: list[str] = [] - self.stderr: list[str] = [] - self.output: queue.SimpleQueue = queue.SimpleQueue() - self.returncode: int | None = None - self.done: bool = False - - -class BaseTask(threading.Thread): - def __init__(self, uuid: UUID) -> None: - # calling parent class constructor - threading.Thread.__init__(self) - - # constructor - self.uuid: UUID = uuid - self.log = logging.getLogger(__name__) - self.procs: list[CmdState] = [] - self.failed: bool = False - self.finished: bool = False - - def run(self) -> None: - self.finished = True - - def run_cmd(self, cmd: list[str]) -> CmdState: - cwd = os.getcwd() - self.log.debug(f"Working directory: {cwd}") - self.log.debug(f"Running command: {shlex.join(cmd)}") - p = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding="utf-8", - # shell=True, - cwd=cwd, - ) - self.procs.append(CmdState(p)) - p_state = self.procs[-1] - - while p.poll() is None: - # Check if stderr is ready to be read from - rlist, _, _ = select.select([p.stderr, p.stdout], [], [], 0) - if p.stderr in rlist: - assert p.stderr is not None - line = p.stderr.readline() - if line != "": - p_state.stderr.append(line.strip("\n")) - self.log.debug(f"stderr: {line}") - p_state.output.put(line) - - if p.stdout in rlist: - assert p.stdout is not None - line = p.stdout.readline() - if line != "": - p_state.stdout.append(line.strip("\n")) - self.log.debug(f"stdout: {line}") - p_state.output.put(line) - - p_state.returncode = p.returncode - p_state.output.put(None) - p_state.done = True - - if p.returncode != 0: - raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}") - - self.log.debug("Successfully ran command") - return p_state - - -class TaskPool: - def __init__(self) -> None: - self.lock: threading.RLock = threading.RLock() - self.pool: dict[UUID, BaseTask] = {} - - def __getitem__(self, uuid: UUID) -> BaseTask: - with self.lock: - return self.pool[uuid] - - def __setitem__(self, uuid: UUID, task: BaseTask) -> None: - with self.lock: - if uuid in self.pool: - raise KeyError(f"Task with uuid {uuid} already exists") - if type(uuid) is not UUID: - raise TypeError("uuid must be of type UUID") - self.pool[uuid] = task - - -POOL: TaskPool = TaskPool() - - -def get_task(uuid: UUID) -> BaseTask: - global POOL - return POOL[uuid] - - -def register_task(task: type, *args: Any) -> UUID: - global POOL - if not issubclass(task, BaseTask): - raise TypeError("task must be a subclass of BaseTask") - - uuid = uuid4() - - inst_task = task(uuid, *args) - POOL[uuid] = inst_task - inst_task.start() - return uuid diff --git a/pkgs/clan-cli/default.nix b/pkgs/clan-cli/default.nix index ac588160d..e252e3c61 100644 --- a/pkgs/clan-cli/default.nix +++ b/pkgs/clan-cli/default.nix @@ -29,6 +29,7 @@ , copyDesktopItems , qemu , gnupg +, e2fsprogs }: let @@ -63,6 +64,7 @@ let sops git qemu + e2fsprogs ]; runtimeDependenciesAsSet = builtins.listToAttrs (builtins.map (p: lib.nameValuePair (lib.getName p.name) p) runtimeDependencies); diff --git a/pkgs/clan-cli/tests/test_flake_api.py b/pkgs/clan-cli/tests/test_flake_api.py index 2fa65d281..f44a94228 100644 --- a/pkgs/clan-cli/tests/test_flake_api.py +++ b/pkgs/clan-cli/tests/test_flake_api.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import pytest @@ -28,3 +29,23 @@ def test_inspect_err(api: TestClient) -> None: data = response.json() print("Data: ", data) assert data.get("detail") + + +@pytest.mark.impure +def test_inspect_flake(api: TestClient, test_flake_with_core: Path) -> None: + params = {"url": str(test_flake_with_core)} + response = api.get( + "/api/flake", + params=params, + ) + assert response.status_code == 200, "Failed to inspect vm" + data = response.json() + print("Data: ", json.dumps(data, indent=2)) + assert data.get("content") is not None + actions = data.get("actions") + assert actions is not None + assert len(actions) == 2 + assert actions[0].get("id") == "vms/inspect" + assert actions[0].get("uri") == "api/vms/inspect" + assert actions[1].get("id") == "vms/create" + assert actions[1].get("uri") == "api/vms/create" diff --git a/pkgs/clan-cli/tests/test_vms_api.py b/pkgs/clan-cli/tests/test_vms_api.py index 5bbc3c6d8..13cfe0200 100644 --- a/pkgs/clan-cli/tests/test_vms_api.py +++ b/pkgs/clan-cli/tests/test_vms_api.py @@ -6,24 +6,6 @@ from api import TestClient from httpx import SyncByteStream -def is_running_in_ci() -> bool: - # Check if pytest is running in GitHub Actions - if os.getenv("GITHUB_ACTIONS") == "true": - print("Running on GitHub Actions") - return True - - # Check if pytest is running in Travis CI - if os.getenv("TRAVIS") == "true": - print("Running on Travis CI") - return True - - # Check if pytest is running in Circle CI - if os.getenv("CIRCLECI") == "true": - print("Running on Circle CI") - return True - return False - - @pytest.mark.impure def test_inspect(api: TestClient, test_flake_with_core: Path) -> None: response = api.post( @@ -49,10 +31,9 @@ def test_incorrect_uuid(api: TestClient) -> None: assert response.status_code == 422, "Failed to get vm status" +@pytest.mark.skipif(not os.path.exists("/dev/kvm"), reason="Requires KVM") @pytest.mark.impure def test_create(api: TestClient, test_flake_with_core: Path) -> None: - if is_running_in_ci(): - pytest.skip("Skipping test in CI. As it requires KVM") print(f"flake_url: {test_flake_with_core} ") response = api.post( "/api/vms/create", @@ -74,20 +55,11 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None: assert response.status_code == 200, "Failed to get vm status" response = api.get(f"/api/vms/{uuid}/logs") - print("=========FLAKE LOGS==========") - assert isinstance(response.stream, SyncByteStream) - for line in response.stream: - assert line != b"", "Failed to get vm logs" - print(line.decode("utf-8"), end="") - print("=========END LOGS==========") - assert response.status_code == 200, "Failed to get vm logs" - - response = api.get(f"/api/vms/{uuid}/logs") - assert isinstance(response.stream, SyncByteStream) print("=========VM LOGS==========") + assert isinstance(response.stream, SyncByteStream) for line in response.stream: assert line != b"", "Failed to get vm logs" - print(line.decode("utf-8"), end="") + print(line.decode("utf-8")) print("=========END LOGS==========") assert response.status_code == 200, "Failed to get vm logs" diff --git a/pkgs/clan-cli/tests/test_vms_cli.py b/pkgs/clan-cli/tests/test_vms_cli.py new file mode 100644 index 000000000..8b365d2a4 --- /dev/null +++ b/pkgs/clan-cli/tests/test_vms_cli.py @@ -0,0 +1,22 @@ +import os +from pathlib import Path + +import pytest +from cli import Cli + +no_kvm = not os.path.exists("/dev/kvm") + + +@pytest.mark.impure +def test_inspect(test_flake_with_core: Path, capsys: pytest.CaptureFixture) -> None: + cli = Cli() + cli.run(["vms", "inspect", "vm1"]) + out = capsys.readouterr() # empty the buffer + assert "Cores" in out.out + + +@pytest.mark.skipif(no_kvm, reason="Requires KVM") +@pytest.mark.impure +def test_create(test_flake_with_core: Path) -> None: + cli = Cli() + cli.run(["vms", "create", "vm1"])