Improving endpoint
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
|
||||
grey = "\x1b[38;20m"
|
||||
yellow = "\x1b[33;20m"
|
||||
red = "\x1b[31;20m"
|
||||
@@ -11,7 +11,8 @@ class CustomFormatter(logging.Formatter):
|
||||
green = "\u001b[32m"
|
||||
blue = "\u001b[34m"
|
||||
|
||||
def format_str(color):
|
||||
@staticmethod
|
||||
def format_str(color: str) -> str:
|
||||
reset = "\x1b[0m"
|
||||
return f"{color}%(levelname)s{reset}:(%(filename)s:%(lineno)d): %(message)s"
|
||||
|
||||
@@ -20,24 +21,23 @@ class CustomFormatter(logging.Formatter):
|
||||
logging.INFO: format_str(green),
|
||||
logging.WARNING: format_str(yellow),
|
||||
logging.ERROR: format_str(red),
|
||||
logging.CRITICAL: format_str(bold_red)
|
||||
logging.CRITICAL: format_str(bold_red),
|
||||
}
|
||||
|
||||
def formatTime(self, record,datefmt=None):
|
||||
def format_time(self, record: Any, datefmt: Any = None) -> str:
|
||||
now = datetime.datetime.now()
|
||||
now = now.strftime("%H:%M:%S")
|
||||
return now
|
||||
|
||||
def format(self, record):
|
||||
def format(self, record: Any) -> str:
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt)
|
||||
formatter.formatTime = self.formatTime
|
||||
formatter.formatTime = self.format_time
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
def register(level):
|
||||
def register(level: Any) -> None:
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(level)
|
||||
ch.setFormatter(CustomFormatter())
|
||||
logging.basicConfig(level=level, handlers=[ch])
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import logging
|
||||
|
||||
from .. import custom_logger
|
||||
from .. import custom_logger
|
||||
from .assets import asset_path
|
||||
from .routers import flake, health, machines, root, vms
|
||||
|
||||
@@ -39,7 +40,7 @@ def setup_app() -> FastAPI:
|
||||
return app
|
||||
|
||||
|
||||
#TODO: How do I get the log level from the command line in here?
|
||||
# TODO: How do I get the log level from the command line in here?
|
||||
custom_logger.register(logging.DEBUG)
|
||||
app = setup_app()
|
||||
|
||||
|
||||
@@ -2,16 +2,27 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import select
|
||||
import queue
|
||||
import shlex
|
||||
import uuid
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Annotated, AsyncIterator
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Body, FastAPI, HTTPException, Request, status, BackgroundTasks
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Body,
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from ...nix import nix_build, nix_eval
|
||||
from ..schemas import VmConfig, VmInspectResponse, VmCreateResponse
|
||||
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
|
||||
|
||||
# Logging setup
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -37,8 +48,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/api/vms/inspect")
|
||||
async def inspect_vm(
|
||||
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
|
||||
@@ -68,9 +77,8 @@ command output:
|
||||
)
|
||||
|
||||
|
||||
|
||||
class NixBuildException(HTTPException):
|
||||
def __init__(self, uuid: uuid.UUID, msg: str,loc: list = ["body", "flake_attr"]):
|
||||
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
|
||||
self.uuid = uuid
|
||||
detail = [
|
||||
{
|
||||
@@ -85,74 +93,161 @@ class NixBuildException(HTTPException):
|
||||
)
|
||||
|
||||
|
||||
|
||||
import threading
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
class ProcessState:
|
||||
def __init__(self, proc: subprocess.Popen):
|
||||
self.proc: subprocess.Process = proc
|
||||
self.stdout: list[str] = []
|
||||
self.stderr: list[str] = []
|
||||
self.returncode: int | None = None
|
||||
self.done: bool = False
|
||||
|
||||
class BuildVM(threading.Thread):
|
||||
def __init__(self, vm: VmConfig, uuid: uuid.UUID):
|
||||
def __init__(self, vm: VmConfig, uuid: UUID):
|
||||
# calling parent class constructor
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
# constructor
|
||||
self.vm: VmConfig = vm
|
||||
self.uuid: uuid.UUID = uuid
|
||||
self.uuid: UUID = uuid
|
||||
self.log = logging.getLogger(__name__)
|
||||
self.process: subprocess.Popen = None
|
||||
self.procs: list[ProcessState] = []
|
||||
self.failed: bool = False
|
||||
self.finished: bool = False
|
||||
|
||||
def run(self):
|
||||
self.log.debug(f"BuildVM with uuid {self.uuid} started")
|
||||
try:
|
||||
|
||||
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
|
||||
(out, err) = self.run_cmd(cmd)
|
||||
vm_path = f'{out.strip()}/bin/run-nixos-vm'
|
||||
self.log.debug(f"BuildVM with uuid {self.uuid} started")
|
||||
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
|
||||
|
||||
self.log.debug(f"vm_path: {vm_path}")
|
||||
proc = self.run_cmd(cmd)
|
||||
out = proc.stdout
|
||||
self.log.debug(f"out: {out}")
|
||||
|
||||
(out, err) = self.run_cmd(vm_path)
|
||||
vm_path = f"{''.join(out)}/bin/run-nixos-vm"
|
||||
self.log.debug(f"vm_path: {vm_path}")
|
||||
self.run_cmd(vm_path)
|
||||
|
||||
self.finished = True
|
||||
except Exception as e:
|
||||
self.failed = True
|
||||
self.finished = True
|
||||
log.exception(e)
|
||||
|
||||
def run_cmd(self, cmd: list[str]):
|
||||
cwd=os.getcwd()
|
||||
def run_cmd(self, cmd: list[str]) -> ProcessState:
|
||||
cwd = os.getcwd()
|
||||
log.debug(f"Working directory: {cwd}")
|
||||
log.debug(f"Running command: {shlex.join(cmd)}")
|
||||
self.process = subprocess.Popen(
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf-8",
|
||||
cwd=cwd,
|
||||
)
|
||||
state = ProcessState(process)
|
||||
self.procs.append(state)
|
||||
|
||||
self.process.wait()
|
||||
if self.process.returncode != 0:
|
||||
raise NixBuildException(self.uuid, f"Failed to run command: {shlex.join(cmd)}")
|
||||
while process.poll() is None:
|
||||
# Check if stderr is ready to be read from
|
||||
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
|
||||
if process.stderr in rlist:
|
||||
line = process.stderr.readline()
|
||||
state.stderr.append(line)
|
||||
if process.stdout in rlist:
|
||||
line = process.stdout.readline()
|
||||
state.stdout.append(line)
|
||||
|
||||
log.info("Successfully ran command")
|
||||
return (self.process.stdout, self.process.stderr)
|
||||
state.returncode = process.returncode
|
||||
state.done = True
|
||||
|
||||
POOL: dict[uuid.UUID, BuildVM] = {}
|
||||
if process.returncode != 0:
|
||||
raise NixBuildException(
|
||||
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
|
||||
)
|
||||
|
||||
log.debug("Successfully ran command")
|
||||
return state
|
||||
|
||||
|
||||
class VmTaskPool:
|
||||
def __init__(self) -> None:
|
||||
self.lock: threading.RLock = threading.RLock()
|
||||
self.pool: dict[UUID, BuildVM] = {}
|
||||
|
||||
def __getitem__(self, uuid: str | UUID) -> BuildVM:
|
||||
with self.lock:
|
||||
if type(uuid) is UUID:
|
||||
return self.pool[uuid]
|
||||
else:
|
||||
uuid = UUID(uuid)
|
||||
return self.pool[uuid]
|
||||
|
||||
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
|
||||
with self.lock:
|
||||
if uuid in self.pool:
|
||||
raise KeyError(f"VM with uuid {uuid} already exists")
|
||||
if type(uuid) is not UUID:
|
||||
raise TypeError("uuid must be of type UUID")
|
||||
self.pool[uuid] = vm
|
||||
|
||||
|
||||
POOL: VmTaskPool = VmTaskPool()
|
||||
|
||||
|
||||
def nix_build_exception_handler(
|
||||
request: Request, exc: NixBuildException
|
||||
) -> JSONResponse:
|
||||
log.error("NixBuildException: %s", exc)
|
||||
del POOL[exc.uuid]
|
||||
# del POOL[exc.uuid]
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=jsonable_encoder(dict(detail=exc.detail)),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/vms/{uuid}/status")
|
||||
async def get_status(uuid: str) -> VmStatusResponse:
|
||||
global POOL
|
||||
handle = POOL[uuid]
|
||||
|
||||
if handle.process.poll() is None:
|
||||
return VmStatusResponse(running=True, status=0)
|
||||
else:
|
||||
return VmStatusResponse(running=False, status=handle.process.returncode)
|
||||
|
||||
|
||||
|
||||
@router.get("/api/vms/{uuid}/logs")
|
||||
async def get_logs(uuid: str) -> StreamingResponse:
|
||||
async def stream_logs() -> AsyncIterator[str]:
|
||||
global POOL
|
||||
handle = POOL[uuid]
|
||||
for proc in handle.procs.values():
|
||||
while True:
|
||||
if proc.stdout.empty() and proc.stderr.empty() and not proc.done:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
|
||||
break
|
||||
for line in proc.stdout:
|
||||
yield line
|
||||
for line in proc.stderr:
|
||||
yield line
|
||||
|
||||
return StreamingResponse(
|
||||
content=stream_logs(),
|
||||
media_type="text/plain",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/vms/create")
|
||||
async def create_vm(vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks) -> StreamingResponse:
|
||||
handle_id = uuid.uuid4()
|
||||
handle = BuildVM(vm, handle_id)
|
||||
async def create_vm(
|
||||
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
|
||||
) -> VmCreateResponse:
|
||||
global POOL
|
||||
uuid = uuid4()
|
||||
handle = BuildVM(vm, uuid)
|
||||
handle.start()
|
||||
POOL[handle_id] = handle
|
||||
return VmCreateResponse(uuid=str(handle_id))
|
||||
|
||||
|
||||
POOL[uuid] = handle
|
||||
return VmCreateResponse(uuid=str(uuid))
|
||||
|
||||
@@ -43,9 +43,16 @@ class VmConfig(BaseModel):
|
||||
memory_size: int
|
||||
graphics: bool
|
||||
|
||||
|
||||
class VmStatusResponse(BaseModel):
|
||||
status: int
|
||||
running: bool
|
||||
|
||||
|
||||
class VmCreateResponse(BaseModel):
|
||||
uuid: str
|
||||
|
||||
|
||||
class VmInspectResponse(BaseModel):
|
||||
config: VmConfig
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
import urllib.request
|
||||
@@ -11,6 +12,8 @@ from typing import Iterator
|
||||
# XXX: can we dynamically load this using nix develop?
|
||||
from uvicorn import run
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def defer_open_browser(base_url: str) -> None:
|
||||
for i in range(5):
|
||||
@@ -24,7 +27,7 @@ def defer_open_browser(base_url: str) -> None:
|
||||
|
||||
@contextmanager
|
||||
def spawn_node_dev_server(host: str, port: int) -> Iterator[None]:
|
||||
logger.info("Starting node dev server...")
|
||||
log.info("Starting node dev server...")
|
||||
path = Path(__file__).parent.parent.parent.parent / "ui"
|
||||
with subprocess.Popen(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user