Improving endpoint
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class CustomFormatter(logging.Formatter):
|
class CustomFormatter(logging.Formatter):
|
||||||
|
|
||||||
grey = "\x1b[38;20m"
|
grey = "\x1b[38;20m"
|
||||||
yellow = "\x1b[33;20m"
|
yellow = "\x1b[33;20m"
|
||||||
red = "\x1b[31;20m"
|
red = "\x1b[31;20m"
|
||||||
@@ -11,7 +11,8 @@ class CustomFormatter(logging.Formatter):
|
|||||||
green = "\u001b[32m"
|
green = "\u001b[32m"
|
||||||
blue = "\u001b[34m"
|
blue = "\u001b[34m"
|
||||||
|
|
||||||
def format_str(color):
|
@staticmethod
|
||||||
|
def format_str(color: str) -> str:
|
||||||
reset = "\x1b[0m"
|
reset = "\x1b[0m"
|
||||||
return f"{color}%(levelname)s{reset}:(%(filename)s:%(lineno)d): %(message)s"
|
return f"{color}%(levelname)s{reset}:(%(filename)s:%(lineno)d): %(message)s"
|
||||||
|
|
||||||
@@ -20,24 +21,23 @@ class CustomFormatter(logging.Formatter):
|
|||||||
logging.INFO: format_str(green),
|
logging.INFO: format_str(green),
|
||||||
logging.WARNING: format_str(yellow),
|
logging.WARNING: format_str(yellow),
|
||||||
logging.ERROR: format_str(red),
|
logging.ERROR: format_str(red),
|
||||||
logging.CRITICAL: format_str(bold_red)
|
logging.CRITICAL: format_str(bold_red),
|
||||||
}
|
}
|
||||||
|
|
||||||
def formatTime(self, record,datefmt=None):
|
def format_time(self, record: Any, datefmt: Any = None) -> str:
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
now = now.strftime("%H:%M:%S")
|
now = now.strftime("%H:%M:%S")
|
||||||
return now
|
return now
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record: Any) -> str:
|
||||||
log_fmt = self.FORMATS.get(record.levelno)
|
log_fmt = self.FORMATS.get(record.levelno)
|
||||||
formatter = logging.Formatter(log_fmt)
|
formatter = logging.Formatter(log_fmt)
|
||||||
formatter.formatTime = self.formatTime
|
formatter.formatTime = self.format_time
|
||||||
return formatter.format(record)
|
return formatter.format(record)
|
||||||
|
|
||||||
|
|
||||||
def register(level):
|
def register(level: Any) -> None:
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
ch.setLevel(level)
|
ch.setLevel(level)
|
||||||
ch.setFormatter(CustomFormatter())
|
ch.setFormatter(CustomFormatter())
|
||||||
logging.basicConfig(level=level, handlers=[ch])
|
logging.basicConfig(level=level, handlers=[ch])
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
import logging
|
|
||||||
|
|
||||||
from .. import custom_logger
|
from .. import custom_logger
|
||||||
from .assets import asset_path
|
from .assets import asset_path
|
||||||
from .routers import flake, health, machines, root, vms
|
from .routers import flake, health, machines, root, vms
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ def setup_app() -> FastAPI:
|
|||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
#TODO: How do I get the log level from the command line in here?
|
# TODO: How do I get the log level from the command line in here?
|
||||||
custom_logger.register(logging.DEBUG)
|
custom_logger.register(logging.DEBUG)
|
||||||
app = setup_app()
|
app = setup_app()
|
||||||
|
|
||||||
|
|||||||
@@ -2,16 +2,27 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import select
|
||||||
|
import queue
|
||||||
import shlex
|
import shlex
|
||||||
import uuid
|
import subprocess
|
||||||
|
import threading
|
||||||
from typing import Annotated, AsyncIterator
|
from typing import Annotated, AsyncIterator
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, FastAPI, HTTPException, Request, status, BackgroundTasks
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
BackgroundTasks,
|
||||||
|
Body,
|
||||||
|
HTTPException,
|
||||||
|
Request,
|
||||||
|
status,
|
||||||
|
)
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from ...nix import nix_build, nix_eval
|
from ...nix import nix_build, nix_eval
|
||||||
from ..schemas import VmConfig, VmInspectResponse, VmCreateResponse
|
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
|
||||||
|
|
||||||
# Logging setup
|
# Logging setup
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -37,8 +48,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/vms/inspect")
|
@router.post("/api/vms/inspect")
|
||||||
async def inspect_vm(
|
async def inspect_vm(
|
||||||
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
|
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
|
||||||
@@ -68,9 +77,8 @@ command output:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class NixBuildException(HTTPException):
|
class NixBuildException(HTTPException):
|
||||||
def __init__(self, uuid: uuid.UUID, msg: str,loc: list = ["body", "flake_attr"]):
|
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
|
||||||
self.uuid = uuid
|
self.uuid = uuid
|
||||||
detail = [
|
detail = [
|
||||||
{
|
{
|
||||||
@@ -85,74 +93,161 @@ class NixBuildException(HTTPException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessState:
|
||||||
import threading
|
def __init__(self, proc: subprocess.Popen):
|
||||||
import subprocess
|
self.proc: subprocess.Process = proc
|
||||||
import uuid
|
self.stdout: list[str] = []
|
||||||
|
self.stderr: list[str] = []
|
||||||
|
self.returncode: int | None = None
|
||||||
|
self.done: bool = False
|
||||||
|
|
||||||
class BuildVM(threading.Thread):
|
class BuildVM(threading.Thread):
|
||||||
def __init__(self, vm: VmConfig, uuid: uuid.UUID):
|
def __init__(self, vm: VmConfig, uuid: UUID):
|
||||||
# calling parent class constructor
|
# calling parent class constructor
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
|
|
||||||
# constructor
|
# constructor
|
||||||
self.vm: VmConfig = vm
|
self.vm: VmConfig = vm
|
||||||
self.uuid: uuid.UUID = uuid
|
self.uuid: UUID = uuid
|
||||||
self.log = logging.getLogger(__name__)
|
self.log = logging.getLogger(__name__)
|
||||||
self.process: subprocess.Popen = None
|
self.procs: list[ProcessState] = []
|
||||||
|
self.failed: bool = False
|
||||||
|
self.finished: bool = False
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.log.debug(f"BuildVM with uuid {self.uuid} started")
|
try:
|
||||||
|
|
||||||
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
|
self.log.debug(f"BuildVM with uuid {self.uuid} started")
|
||||||
(out, err) = self.run_cmd(cmd)
|
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
|
||||||
vm_path = f'{out.strip()}/bin/run-nixos-vm'
|
|
||||||
|
|
||||||
self.log.debug(f"vm_path: {vm_path}")
|
proc = self.run_cmd(cmd)
|
||||||
|
out = proc.stdout
|
||||||
|
self.log.debug(f"out: {out}")
|
||||||
|
|
||||||
(out, err) = self.run_cmd(vm_path)
|
vm_path = f"{''.join(out)}/bin/run-nixos-vm"
|
||||||
|
self.log.debug(f"vm_path: {vm_path}")
|
||||||
|
self.run_cmd(vm_path)
|
||||||
|
|
||||||
|
self.finished = True
|
||||||
|
except Exception as e:
|
||||||
|
self.failed = True
|
||||||
|
self.finished = True
|
||||||
|
log.exception(e)
|
||||||
|
|
||||||
def run_cmd(self, cmd: list[str]):
|
def run_cmd(self, cmd: list[str]) -> ProcessState:
|
||||||
cwd=os.getcwd()
|
cwd = os.getcwd()
|
||||||
log.debug(f"Working directory: {cwd}")
|
log.debug(f"Working directory: {cwd}")
|
||||||
log.debug(f"Running command: {shlex.join(cmd)}")
|
log.debug(f"Running command: {shlex.join(cmd)}")
|
||||||
self.process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
cmd,
|
cmd,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
)
|
)
|
||||||
|
state = ProcessState(process)
|
||||||
|
self.procs.append(state)
|
||||||
|
|
||||||
self.process.wait()
|
while process.poll() is None:
|
||||||
if self.process.returncode != 0:
|
# Check if stderr is ready to be read from
|
||||||
raise NixBuildException(self.uuid, f"Failed to run command: {shlex.join(cmd)}")
|
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
|
||||||
|
if process.stderr in rlist:
|
||||||
|
line = process.stderr.readline()
|
||||||
|
state.stderr.append(line)
|
||||||
|
if process.stdout in rlist:
|
||||||
|
line = process.stdout.readline()
|
||||||
|
state.stdout.append(line)
|
||||||
|
|
||||||
log.info("Successfully ran command")
|
state.returncode = process.returncode
|
||||||
return (self.process.stdout, self.process.stderr)
|
state.done = True
|
||||||
|
|
||||||
POOL: dict[uuid.UUID, BuildVM] = {}
|
if process.returncode != 0:
|
||||||
|
raise NixBuildException(
|
||||||
|
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
log.debug("Successfully ran command")
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
class VmTaskPool:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.lock: threading.RLock = threading.RLock()
|
||||||
|
self.pool: dict[UUID, BuildVM] = {}
|
||||||
|
|
||||||
|
def __getitem__(self, uuid: str | UUID) -> BuildVM:
|
||||||
|
with self.lock:
|
||||||
|
if type(uuid) is UUID:
|
||||||
|
return self.pool[uuid]
|
||||||
|
else:
|
||||||
|
uuid = UUID(uuid)
|
||||||
|
return self.pool[uuid]
|
||||||
|
|
||||||
|
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
|
||||||
|
with self.lock:
|
||||||
|
if uuid in self.pool:
|
||||||
|
raise KeyError(f"VM with uuid {uuid} already exists")
|
||||||
|
if type(uuid) is not UUID:
|
||||||
|
raise TypeError("uuid must be of type UUID")
|
||||||
|
self.pool[uuid] = vm
|
||||||
|
|
||||||
|
|
||||||
|
POOL: VmTaskPool = VmTaskPool()
|
||||||
|
|
||||||
|
|
||||||
def nix_build_exception_handler(
|
def nix_build_exception_handler(
|
||||||
request: Request, exc: NixBuildException
|
request: Request, exc: NixBuildException
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
log.error("NixBuildException: %s", exc)
|
log.error("NixBuildException: %s", exc)
|
||||||
del POOL[exc.uuid]
|
# del POOL[exc.uuid]
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=exc.status_code,
|
status_code=exc.status_code,
|
||||||
content=jsonable_encoder(dict(detail=exc.detail)),
|
content=jsonable_encoder(dict(detail=exc.detail)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api/vms/{uuid}/status")
|
||||||
|
async def get_status(uuid: str) -> VmStatusResponse:
|
||||||
|
global POOL
|
||||||
|
handle = POOL[uuid]
|
||||||
|
|
||||||
|
if handle.process.poll() is None:
|
||||||
|
return VmStatusResponse(running=True, status=0)
|
||||||
|
else:
|
||||||
|
return VmStatusResponse(running=False, status=handle.process.returncode)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api/vms/{uuid}/logs")
|
||||||
|
async def get_logs(uuid: str) -> StreamingResponse:
|
||||||
|
async def stream_logs() -> AsyncIterator[str]:
|
||||||
|
global POOL
|
||||||
|
handle = POOL[uuid]
|
||||||
|
for proc in handle.procs.values():
|
||||||
|
while True:
|
||||||
|
if proc.stdout.empty() and proc.stderr.empty() and not proc.done:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
continue
|
||||||
|
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
|
||||||
|
break
|
||||||
|
for line in proc.stdout:
|
||||||
|
yield line
|
||||||
|
for line in proc.stderr:
|
||||||
|
yield line
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
content=stream_logs(),
|
||||||
|
media_type="text/plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/vms/create")
|
@router.post("/api/vms/create")
|
||||||
async def create_vm(vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks) -> StreamingResponse:
|
async def create_vm(
|
||||||
handle_id = uuid.uuid4()
|
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
|
||||||
handle = BuildVM(vm, handle_id)
|
) -> VmCreateResponse:
|
||||||
|
global POOL
|
||||||
|
uuid = uuid4()
|
||||||
|
handle = BuildVM(vm, uuid)
|
||||||
handle.start()
|
handle.start()
|
||||||
POOL[handle_id] = handle
|
POOL[uuid] = handle
|
||||||
return VmCreateResponse(uuid=str(handle_id))
|
return VmCreateResponse(uuid=str(uuid))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -43,9 +43,16 @@ class VmConfig(BaseModel):
|
|||||||
memory_size: int
|
memory_size: int
|
||||||
graphics: bool
|
graphics: bool
|
||||||
|
|
||||||
|
|
||||||
|
class VmStatusResponse(BaseModel):
|
||||||
|
status: int
|
||||||
|
running: bool
|
||||||
|
|
||||||
|
|
||||||
class VmCreateResponse(BaseModel):
|
class VmCreateResponse(BaseModel):
|
||||||
uuid: str
|
uuid: str
|
||||||
|
|
||||||
|
|
||||||
class VmInspectResponse(BaseModel):
|
class VmInspectResponse(BaseModel):
|
||||||
config: VmConfig
|
config: VmConfig
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import urllib.request
|
import urllib.request
|
||||||
@@ -11,6 +12,8 @@ from typing import Iterator
|
|||||||
# XXX: can we dynamically load this using nix develop?
|
# XXX: can we dynamically load this using nix develop?
|
||||||
from uvicorn import run
|
from uvicorn import run
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def defer_open_browser(base_url: str) -> None:
|
def defer_open_browser(base_url: str) -> None:
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
@@ -24,7 +27,7 @@ def defer_open_browser(base_url: str) -> None:
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def spawn_node_dev_server(host: str, port: int) -> Iterator[None]:
|
def spawn_node_dev_server(host: str, port: int) -> Iterator[None]:
|
||||||
logger.info("Starting node dev server...")
|
log.info("Starting node dev server...")
|
||||||
path = Path(__file__).parent.parent.parent.parent / "ui"
|
path = Path(__file__).parent.parent.parent.parent / "ui"
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user