From f6c8b963c195f76100da3800ca375b6295bd7a8c Mon Sep 17 00:00:00 2001 From: Qubasa Date: Mon, 25 Sep 2023 20:09:27 +0200 Subject: [PATCH] Improving endpoint --- pkgs/clan-cli/clan_cli/custom_logger.py | 16 +- pkgs/clan-cli/clan_cli/webui/app.py | 7 +- pkgs/clan-cli/clan_cli/webui/routers/vms.py | 171 +++++++++++++++----- pkgs/clan-cli/clan_cli/webui/schemas.py | 7 + pkgs/clan-cli/clan_cli/webui/server.py | 5 +- 5 files changed, 156 insertions(+), 50 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/custom_logger.py b/pkgs/clan-cli/clan_cli/custom_logger.py index 8566b9516..b16fc8904 100644 --- a/pkgs/clan-cli/clan_cli/custom_logger.py +++ b/pkgs/clan-cli/clan_cli/custom_logger.py @@ -1,9 +1,9 @@ import datetime import logging +from typing import Any class CustomFormatter(logging.Formatter): - grey = "\x1b[38;20m" yellow = "\x1b[33;20m" red = "\x1b[31;20m" @@ -11,7 +11,8 @@ class CustomFormatter(logging.Formatter): green = "\u001b[32m" blue = "\u001b[34m" - def format_str(color): + @staticmethod + def format_str(color: str) -> str: reset = "\x1b[0m" return f"{color}%(levelname)s{reset}:(%(filename)s:%(lineno)d): %(message)s" @@ -20,24 +21,23 @@ class CustomFormatter(logging.Formatter): logging.INFO: format_str(green), logging.WARNING: format_str(yellow), logging.ERROR: format_str(red), - logging.CRITICAL: format_str(bold_red) + logging.CRITICAL: format_str(bold_red), } - def formatTime(self, record,datefmt=None): + def format_time(self, record: Any, datefmt: Any = None) -> str: now = datetime.datetime.now() now = now.strftime("%H:%M:%S") return now - def format(self, record): + def format(self, record: Any) -> str: log_fmt = self.FORMATS.get(record.levelno) formatter = logging.Formatter(log_fmt) - formatter.formatTime = self.formatTime + formatter.formatTime = self.format_time return formatter.format(record) -def register(level): +def register(level: Any) -> None: ch = logging.StreamHandler() ch.setLevel(level) ch.setFormatter(CustomFormatter()) logging.basicConfig(level=level, handlers=[ch]) - diff --git a/pkgs/clan-cli/clan_cli/webui/app.py b/pkgs/clan-cli/clan_cli/webui/app.py index 0c03bccd6..bd586789d 100644 --- a/pkgs/clan-cli/clan_cli/webui/app.py +++ b/pkgs/clan-cli/clan_cli/webui/app.py @@ -1,10 +1,11 @@ +import logging + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.routing import APIRoute from fastapi.staticfiles import StaticFiles -import logging -from .. import custom_logger +from .. import custom_logger from .assets import asset_path from .routers import flake, health, machines, root, vms @@ -39,7 +40,7 @@ def setup_app() -> FastAPI: return app -#TODO: How do I get the log level from the command line in here? +# TODO: How do I get the log level from the command line in here? custom_logger.register(logging.DEBUG) app = setup_app() diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index 9ed4579e4..47566fadd 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -2,16 +2,27 @@ import asyncio import json import logging import os +import select +import queue import shlex -import uuid +import subprocess +import threading from typing import Annotated, AsyncIterator +from uuid import UUID, uuid4 -from fastapi import APIRouter, Body, FastAPI, HTTPException, Request, status, BackgroundTasks +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + HTTPException, + Request, + status, +) from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse, StreamingResponse from ...nix import nix_build, nix_eval -from ..schemas import VmConfig, VmInspectResponse, VmCreateResponse +from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse # Logging setup log = logging.getLogger(__name__) @@ -37,8 +48,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]: ) - - @router.post("/api/vms/inspect") async def inspect_vm( flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()] @@ -68,9 +77,8 @@ command output: ) - class NixBuildException(HTTPException): - def __init__(self, uuid: uuid.UUID, msg: str,loc: list = ["body", "flake_attr"]): + def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]): self.uuid = uuid detail = [ { @@ -85,74 +93,161 @@ class NixBuildException(HTTPException): ) - -import threading -import subprocess -import uuid - +class ProcessState: + def __init__(self, proc: subprocess.Popen): + self.proc: subprocess.Process = proc + self.stdout: list[str] = [] + self.stderr: list[str] = [] + self.returncode: int | None = None + self.done: bool = False class BuildVM(threading.Thread): - def __init__(self, vm: VmConfig, uuid: uuid.UUID): + def __init__(self, vm: VmConfig, uuid: UUID): # calling parent class constructor threading.Thread.__init__(self) # constructor self.vm: VmConfig = vm - self.uuid: uuid.UUID = uuid + self.uuid: UUID = uuid self.log = logging.getLogger(__name__) - self.process: subprocess.Popen = None + self.procs: list[ProcessState] = [] + self.failed: bool = False + self.finished: bool = False def run(self): - self.log.debug(f"BuildVM with uuid {self.uuid} started") + try: - cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url) - (out, err) = self.run_cmd(cmd) - vm_path = f'{out.strip()}/bin/run-nixos-vm' + self.log.debug(f"BuildVM with uuid {self.uuid} started") + cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url) - self.log.debug(f"vm_path: {vm_path}") + proc = self.run_cmd(cmd) + out = proc.stdout + self.log.debug(f"out: {out}") - (out, err) = self.run_cmd(vm_path) + vm_path = f"{''.join(out)}/bin/run-nixos-vm" + self.log.debug(f"vm_path: {vm_path}") + self.run_cmd(vm_path) + self.finished = True + except Exception as e: + self.failed = True + self.finished = True + log.exception(e) - def run_cmd(self, cmd: list[str]): - cwd=os.getcwd() + def run_cmd(self, cmd: list[str]) -> ProcessState: + cwd = os.getcwd() log.debug(f"Working directory: {cwd}") log.debug(f"Running command: {shlex.join(cmd)}") - self.process = subprocess.Popen( + process = subprocess.Popen( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8", cwd=cwd, ) + state = ProcessState(process) + self.procs.append(state) - self.process.wait() - if self.process.returncode != 0: - raise NixBuildException(self.uuid, f"Failed to run command: {shlex.join(cmd)}") + while process.poll() is None: + # Check if stderr is ready to be read from + rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0) + if process.stderr in rlist: + line = process.stderr.readline() + state.stderr.append(line) + if process.stdout in rlist: + line = process.stdout.readline() + state.stdout.append(line) - log.info("Successfully ran command") - return (self.process.stdout, self.process.stderr) + state.returncode = process.returncode + state.done = True -POOL: dict[uuid.UUID, BuildVM] = {} + if process.returncode != 0: + raise NixBuildException( + self.uuid, f"Failed to run command: {shlex.join(cmd)}" + ) + + log.debug("Successfully ran command") + return state + + +class VmTaskPool: + def __init__(self) -> None: + self.lock: threading.RLock = threading.RLock() + self.pool: dict[UUID, BuildVM] = {} + + def __getitem__(self, uuid: str | UUID) -> BuildVM: + with self.lock: + if type(uuid) is UUID: + return self.pool[uuid] + else: + uuid = UUID(uuid) + return self.pool[uuid] + + def __setitem__(self, uuid: UUID, vm: BuildVM) -> None: + with self.lock: + if uuid in self.pool: + raise KeyError(f"VM with uuid {uuid} already exists") + if type(uuid) is not UUID: + raise TypeError("uuid must be of type UUID") + self.pool[uuid] = vm + + +POOL: VmTaskPool = VmTaskPool() def nix_build_exception_handler( request: Request, exc: NixBuildException ) -> JSONResponse: log.error("NixBuildException: %s", exc) - del POOL[exc.uuid] + # del POOL[exc.uuid] return JSONResponse( status_code=exc.status_code, content=jsonable_encoder(dict(detail=exc.detail)), ) +@router.get("/api/vms/{uuid}/status") +async def get_status(uuid: str) -> VmStatusResponse: + global POOL + handle = POOL[uuid] + + if handle.process.poll() is None: + return VmStatusResponse(running=True, status=0) + else: + return VmStatusResponse(running=False, status=handle.process.returncode) + + + +@router.get("/api/vms/{uuid}/logs") +async def get_logs(uuid: str) -> StreamingResponse: + async def stream_logs() -> AsyncIterator[str]: + global POOL + handle = POOL[uuid] + for proc in handle.procs.values(): + while True: + if proc.stdout.empty() and proc.stderr.empty() and not proc.done: + await asyncio.sleep(0.1) + continue + if proc.stdout.empty() and proc.stderr.empty() and proc.done: + break + for line in proc.stdout: + yield line + for line in proc.stderr: + yield line + + return StreamingResponse( + content=stream_logs(), + media_type="text/plain", + ) + + @router.post("/api/vms/create") -async def create_vm(vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks) -> StreamingResponse: - handle_id = uuid.uuid4() - handle = BuildVM(vm, handle_id) +async def create_vm( + vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks +) -> VmCreateResponse: + global POOL + uuid = uuid4() + handle = BuildVM(vm, uuid) handle.start() - POOL[handle_id] = handle - return VmCreateResponse(uuid=str(handle_id)) - - + POOL[uuid] = handle + return VmCreateResponse(uuid=str(uuid)) diff --git a/pkgs/clan-cli/clan_cli/webui/schemas.py b/pkgs/clan-cli/clan_cli/webui/schemas.py index 8ee819ce9..874e18aba 100644 --- a/pkgs/clan-cli/clan_cli/webui/schemas.py +++ b/pkgs/clan-cli/clan_cli/webui/schemas.py @@ -43,9 +43,16 @@ class VmConfig(BaseModel): memory_size: int graphics: bool + +class VmStatusResponse(BaseModel): + status: int + running: bool + + class VmCreateResponse(BaseModel): uuid: str + class VmInspectResponse(BaseModel): config: VmConfig diff --git a/pkgs/clan-cli/clan_cli/webui/server.py b/pkgs/clan-cli/clan_cli/webui/server.py index 800cdab5d..8d67d5a45 100644 --- a/pkgs/clan-cli/clan_cli/webui/server.py +++ b/pkgs/clan-cli/clan_cli/webui/server.py @@ -1,4 +1,5 @@ import argparse +import logging import subprocess import time import urllib.request @@ -11,6 +12,8 @@ from typing import Iterator # XXX: can we dynamically load this using nix develop? from uvicorn import run +log = logging.getLogger(__name__) + def defer_open_browser(base_url: str) -> None: for i in range(5): @@ -24,7 +27,7 @@ def defer_open_browser(base_url: str) -> None: @contextmanager def spawn_node_dev_server(host: str, port: int) -> Iterator[None]: - logger.info("Starting node dev server...") + log.info("Starting node dev server...") path = Path(__file__).parent.parent.parent.parent / "ui" with subprocess.Popen( [