Improving endpoint

This commit is contained in:
Qubasa
2023-09-25 20:09:27 +02:00
committed by Mic92
parent d16bb5db26
commit f6c8b963c1
5 changed files with 156 additions and 50 deletions

View File

@@ -1,9 +1,9 @@
import datetime import datetime
import logging import logging
from typing import Any
class CustomFormatter(logging.Formatter): class CustomFormatter(logging.Formatter):
grey = "\x1b[38;20m" grey = "\x1b[38;20m"
yellow = "\x1b[33;20m" yellow = "\x1b[33;20m"
red = "\x1b[31;20m" red = "\x1b[31;20m"
@@ -11,7 +11,8 @@ class CustomFormatter(logging.Formatter):
green = "\u001b[32m" green = "\u001b[32m"
blue = "\u001b[34m" blue = "\u001b[34m"
def format_str(color): @staticmethod
def format_str(color: str) -> str:
reset = "\x1b[0m" reset = "\x1b[0m"
return f"{color}%(levelname)s{reset}:(%(filename)s:%(lineno)d): %(message)s" return f"{color}%(levelname)s{reset}:(%(filename)s:%(lineno)d): %(message)s"
@@ -20,24 +21,23 @@ class CustomFormatter(logging.Formatter):
logging.INFO: format_str(green), logging.INFO: format_str(green),
logging.WARNING: format_str(yellow), logging.WARNING: format_str(yellow),
logging.ERROR: format_str(red), logging.ERROR: format_str(red),
logging.CRITICAL: format_str(bold_red) logging.CRITICAL: format_str(bold_red),
} }
def formatTime(self, record,datefmt=None): def format_time(self, record: Any, datefmt: Any = None) -> str:
now = datetime.datetime.now() now = datetime.datetime.now()
now = now.strftime("%H:%M:%S") now = now.strftime("%H:%M:%S")
return now return now
def format(self, record): def format(self, record: Any) -> str:
log_fmt = self.FORMATS.get(record.levelno) log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt) formatter = logging.Formatter(log_fmt)
formatter.formatTime = self.formatTime formatter.formatTime = self.format_time
return formatter.format(record) return formatter.format(record)
def register(level): def register(level: Any) -> None:
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(level) ch.setLevel(level)
ch.setFormatter(CustomFormatter()) ch.setFormatter(CustomFormatter())
logging.basicConfig(level=level, handlers=[ch]) logging.basicConfig(level=level, handlers=[ch])

View File

@@ -1,10 +1,11 @@
import logging
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
import logging
from .. import custom_logger from .. import custom_logger
from .assets import asset_path from .assets import asset_path
from .routers import flake, health, machines, root, vms from .routers import flake, health, machines, root, vms
@@ -39,7 +40,7 @@ def setup_app() -> FastAPI:
return app return app
#TODO: How do I get the log level from the command line in here? # TODO: How do I get the log level from the command line in here?
custom_logger.register(logging.DEBUG) custom_logger.register(logging.DEBUG)
app = setup_app() app = setup_app()

View File

@@ -2,16 +2,27 @@ import asyncio
import json import json
import logging import logging
import os import os
import select
import queue
import shlex import shlex
import uuid import subprocess
import threading
from typing import Annotated, AsyncIterator from typing import Annotated, AsyncIterator
from uuid import UUID, uuid4
from fastapi import APIRouter, Body, FastAPI, HTTPException, Request, status, BackgroundTasks from fastapi import (
APIRouter,
BackgroundTasks,
Body,
HTTPException,
Request,
status,
)
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from ...nix import nix_build, nix_eval from ...nix import nix_build, nix_eval
from ..schemas import VmConfig, VmInspectResponse, VmCreateResponse from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
# Logging setup # Logging setup
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -37,8 +48,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
) )
@router.post("/api/vms/inspect") @router.post("/api/vms/inspect")
async def inspect_vm( async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()] flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
@@ -68,9 +77,8 @@ command output:
) )
class NixBuildException(HTTPException): class NixBuildException(HTTPException):
def __init__(self, uuid: uuid.UUID, msg: str,loc: list = ["body", "flake_attr"]): def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
self.uuid = uuid self.uuid = uuid
detail = [ detail = [
{ {
@@ -85,74 +93,161 @@ class NixBuildException(HTTPException):
) )
class ProcessState:
import threading def __init__(self, proc: subprocess.Popen):
import subprocess self.proc: subprocess.Process = proc
import uuid self.stdout: list[str] = []
self.stderr: list[str] = []
self.returncode: int | None = None
self.done: bool = False
class BuildVM(threading.Thread): class BuildVM(threading.Thread):
def __init__(self, vm: VmConfig, uuid: uuid.UUID): def __init__(self, vm: VmConfig, uuid: UUID):
# calling parent class constructor # calling parent class constructor
threading.Thread.__init__(self) threading.Thread.__init__(self)
# constructor # constructor
self.vm: VmConfig = vm self.vm: VmConfig = vm
self.uuid: uuid.UUID = uuid self.uuid: UUID = uuid
self.log = logging.getLogger(__name__) self.log = logging.getLogger(__name__)
self.process: subprocess.Popen = None self.procs: list[ProcessState] = []
self.failed: bool = False
self.finished: bool = False
def run(self): def run(self):
self.log.debug(f"BuildVM with uuid {self.uuid} started") try:
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url) self.log.debug(f"BuildVM with uuid {self.uuid} started")
(out, err) = self.run_cmd(cmd) cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
vm_path = f'{out.strip()}/bin/run-nixos-vm'
self.log.debug(f"vm_path: {vm_path}") proc = self.run_cmd(cmd)
out = proc.stdout
self.log.debug(f"out: {out}")
(out, err) = self.run_cmd(vm_path) vm_path = f"{''.join(out)}/bin/run-nixos-vm"
self.log.debug(f"vm_path: {vm_path}")
self.run_cmd(vm_path)
self.finished = True
except Exception as e:
self.failed = True
self.finished = True
log.exception(e)
def run_cmd(self, cmd: list[str]): def run_cmd(self, cmd: list[str]) -> ProcessState:
cwd=os.getcwd() cwd = os.getcwd()
log.debug(f"Working directory: {cwd}") log.debug(f"Working directory: {cwd}")
log.debug(f"Running command: {shlex.join(cmd)}") log.debug(f"Running command: {shlex.join(cmd)}")
self.process = subprocess.Popen( process = subprocess.Popen(
cmd, cmd,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
encoding="utf-8", encoding="utf-8",
cwd=cwd, cwd=cwd,
) )
state = ProcessState(process)
self.procs.append(state)
self.process.wait() while process.poll() is None:
if self.process.returncode != 0: # Check if stderr is ready to be read from
raise NixBuildException(self.uuid, f"Failed to run command: {shlex.join(cmd)}") rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
if process.stderr in rlist:
line = process.stderr.readline()
state.stderr.append(line)
if process.stdout in rlist:
line = process.stdout.readline()
state.stdout.append(line)
log.info("Successfully ran command") state.returncode = process.returncode
return (self.process.stdout, self.process.stderr) state.done = True
POOL: dict[uuid.UUID, BuildVM] = {} if process.returncode != 0:
raise NixBuildException(
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
)
log.debug("Successfully ran command")
return state
class VmTaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BuildVM] = {}
def __getitem__(self, uuid: str | UUID) -> BuildVM:
with self.lock:
if type(uuid) is UUID:
return self.pool[uuid]
else:
uuid = UUID(uuid)
return self.pool[uuid]
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"VM with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = vm
POOL: VmTaskPool = VmTaskPool()
def nix_build_exception_handler( def nix_build_exception_handler(
request: Request, exc: NixBuildException request: Request, exc: NixBuildException
) -> JSONResponse: ) -> JSONResponse:
log.error("NixBuildException: %s", exc) log.error("NixBuildException: %s", exc)
del POOL[exc.uuid] # del POOL[exc.uuid]
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)), content=jsonable_encoder(dict(detail=exc.detail)),
) )
@router.get("/api/vms/{uuid}/status")
async def get_status(uuid: str) -> VmStatusResponse:
global POOL
handle = POOL[uuid]
if handle.process.poll() is None:
return VmStatusResponse(running=True, status=0)
else:
return VmStatusResponse(running=False, status=handle.process.returncode)
@router.get("/api/vms/{uuid}/logs")
async def get_logs(uuid: str) -> StreamingResponse:
async def stream_logs() -> AsyncIterator[str]:
global POOL
handle = POOL[uuid]
for proc in handle.procs.values():
while True:
if proc.stdout.empty() and proc.stderr.empty() and not proc.done:
await asyncio.sleep(0.1)
continue
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
break
for line in proc.stdout:
yield line
for line in proc.stderr:
yield line
return StreamingResponse(
content=stream_logs(),
media_type="text/plain",
)
@router.post("/api/vms/create") @router.post("/api/vms/create")
async def create_vm(vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks) -> StreamingResponse: async def create_vm(
handle_id = uuid.uuid4() vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
handle = BuildVM(vm, handle_id) ) -> VmCreateResponse:
global POOL
uuid = uuid4()
handle = BuildVM(vm, uuid)
handle.start() handle.start()
POOL[handle_id] = handle POOL[uuid] = handle
return VmCreateResponse(uuid=str(handle_id)) return VmCreateResponse(uuid=str(uuid))

View File

@@ -43,9 +43,16 @@ class VmConfig(BaseModel):
memory_size: int memory_size: int
graphics: bool graphics: bool
class VmStatusResponse(BaseModel):
status: int
running: bool
class VmCreateResponse(BaseModel): class VmCreateResponse(BaseModel):
uuid: str uuid: str
class VmInspectResponse(BaseModel): class VmInspectResponse(BaseModel):
config: VmConfig config: VmConfig

View File

@@ -1,4 +1,5 @@
import argparse import argparse
import logging
import subprocess import subprocess
import time import time
import urllib.request import urllib.request
@@ -11,6 +12,8 @@ from typing import Iterator
# XXX: can we dynamically load this using nix develop? # XXX: can we dynamically load this using nix develop?
from uvicorn import run from uvicorn import run
log = logging.getLogger(__name__)
def defer_open_browser(base_url: str) -> None: def defer_open_browser(base_url: str) -> None:
for i in range(5): for i in range(5):
@@ -24,7 +27,7 @@ def defer_open_browser(base_url: str) -> None:
@contextmanager @contextmanager
def spawn_node_dev_server(host: str, port: int) -> Iterator[None]: def spawn_node_dev_server(host: str, port: int) -> Iterator[None]:
logger.info("Starting node dev server...") log.info("Starting node dev server...")
path = Path(__file__).parent.parent.parent.parent / "ui" path = Path(__file__).parent.parent.parent.parent / "ui"
with subprocess.Popen( with subprocess.Popen(
[ [