move out vm logic out of controller

This commit is contained in:
Jörg Thalheim
2023-10-03 16:47:14 +02:00
parent 6de1aeebb9
commit c03effed54
10 changed files with 203 additions and 218 deletions

View File

@@ -5,8 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles
from ..errors import ClanError
from .assets import asset_path
from .routers import flake, health, machines, root, utils, vms
from .error_handlers import clan_error_handler
from .routers import flake, health, machines, root, vms
origins = [
"http://localhost:3000",
@@ -32,9 +34,7 @@ def setup_app() -> FastAPI:
# Needs to be last in register. Because of wildcard route
app.include_router(root.router)
app.add_exception_handler(
utils.NixBuildException, utils.nix_build_exception_handler
)
app.add_exception_handler(ClanError, clan_error_handler)
app.mount("/static", StaticFiles(directory=asset_path()), name="static")

View File

@@ -0,0 +1,23 @@
import logging
from fastapi import Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from ..errors import ClanError
log = logging.getLogger(__name__)
def clan_error_handler(request: Request, exc: ClanError) -> JSONResponse:
log.error("ClanError: %s", exc)
detail = [
{
"loc": [],
"msg": str(exc),
}
]
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder(dict(detail=detail)),
)

View File

@@ -6,15 +6,15 @@ from fastapi import APIRouter, HTTPException
from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse
from ...async_cmd import run
from ...nix import nix_command, nix_flake_show
from .utils import run_cmd
router = APIRouter()
async def get_attrs(url: str) -> list[str]:
cmd = nix_flake_show(url)
stdout = await run_cmd(cmd)
stdout = await run(cmd)
data: dict[str, dict] = {}
try:
@@ -45,7 +45,7 @@ async def inspect_flake(
# Extract the flake from the given URL
# We do this by running 'nix flake prefetch {url} --json'
cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"])
stdout = await run_cmd(cmd)
stdout = await run(cmd)
data: dict[str, str] = json.loads(stdout)
if data.get("storePath") is None:

View File

@@ -1,54 +0,0 @@
import asyncio
import logging
import shlex
from fastapi import HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
log = logging.getLogger(__name__)
class NixBuildException(HTTPException):
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
detail = [
{
"loc": loc,
"msg": msg,
"type": "value_error",
}
]
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
)
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
async def run_cmd(cmd: list[str]) -> bytes:
log.debug(f"Running command: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@@ -1,7 +1,4 @@
import json
import logging
import tempfile
from pathlib import Path
from typing import Annotated, Iterator
from uuid import UUID
@@ -11,131 +8,20 @@ from fastapi.responses import StreamingResponse
from clan_cli.webui.routers.flake import get_attrs
from ...nix import nix_build, nix_eval, nix_shell
from ...task_manager import get_task
from ...vms import create, inspect
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__)
router = APIRouter()
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_eval(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.clan.vm.config"
]
)
def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_build(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.build.vm"
]
)
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
cmd.run(
nix_build(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
]
)
)
vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
return json.load(f)
def task_run(self) -> None:
cmds = self.register_cmds(4)
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
vm_config = self.get_vm_create_info(cmds)
with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
[
"mkfs.ext4",
"-L",
"nixos",
disk_img,
]
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
"-smp", str(vm_config["cores"]),
"-device", "virtio-rng-pci",
"-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
"-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
"-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
"-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
"-device", "virtio-keyboard",
"-usb",
"-device", "usb-tablet,bus=usb-bus.0",
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
)
)
@router.post("/api/vms/inspect")
async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
stdout = await run_cmd(cmd)
data = json.loads(stdout)
return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
)
config = await inspect.inspect_vm(flake_url, flake_attr)
return VmInspectResponse(config=config)
@router.get("/api/vms/{uuid}/status")
@@ -168,5 +54,5 @@ async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
)
uuid = register_task(BuildVmTask, vm)
uuid = create.create_vm(vm)
return VmCreateResponse(uuid=str(uuid))

View File

@@ -3,6 +3,8 @@ from typing import List
from pydantic import BaseModel, Field
from ..vms.inspect import VmConfig
class Status(Enum):
ONLINE = "online"
@@ -35,15 +37,6 @@ class SchemaResponse(BaseModel):
schema_: dict = Field(alias="schema")
class VmConfig(BaseModel):
flake_url: str
flake_attr: str
cores: int
memory_size: int
graphics: bool
class VmStatusResponse(BaseModel):
returncode: list[int | None]
running: bool

View File

@@ -1,169 +0,0 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import threading
from typing import Any, Iterator
from uuid import UUID, uuid4
class CmdState:
def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log
self.p: subprocess.Popen | None = None
self.stdout: list[str] = []
self.stderr: list[str] = []
self._output: queue.SimpleQueue = queue.SimpleQueue()
self.returncode: int | None = None
self.done: bool = False
self.running: bool = False
self.cmd_str: str | None = None
self.workdir: str | None = None
def close_queue(self) -> None:
if self.p is not None:
self.returncode = self.p.returncode
self._output.put(None)
self.running = False
self.done = True
def run(self, cmd: list[str]) -> None:
self.running = True
try:
self.cmd_str = shlex.join(cmd)
self.workdir = os.getcwd()
self.log.debug(f"Working directory: {self.workdir}")
self.log.debug(f"Running command: {shlex.join(cmd)}")
self.p = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=self.workdir,
)
while self.p.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
if self.p.stderr in rlist:
assert self.p.stderr is not None
line = self.p.stderr.readline()
if line != "":
line = line.strip("\n")
self.stderr.append(line)
self.log.debug("stderr: %s", line)
self._output.put(line + "\n")
if self.p.stdout in rlist:
assert self.p.stdout is not None
line = self.p.stdout.readline()
if line != "":
line = line.strip("\n")
self.stdout.append(line)
self.log.debug("stdout: %s", line)
self._output.put(line + "\n")
if self.p.returncode != 0:
raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
self.log.debug("Successfully ran command")
finally:
self.close_queue()
class BaseTask(threading.Thread):
def __init__(self, uuid: UUID) -> None:
# calling parent class constructor
threading.Thread.__init__(self)
# constructor
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.procs: list[CmdState] = []
self.failed: bool = False
self.finished: bool = False
self.logs_lock = threading.Lock()
def run(self) -> None:
try:
self.task_run()
except Exception as e:
for proc in self.procs:
proc.close_queue()
self.failed = True
self.log.exception(e)
finally:
self.finished = True
def task_run(self) -> None:
raise NotImplementedError
## TODO: If two clients are connected to the same task,
def logs_iter(self) -> Iterator[str]:
with self.logs_lock:
for proc in self.procs:
if self.finished:
self.log.debug("log iter: Task is finished")
break
if proc.done:
for line in proc.stderr:
yield line + "\n"
for line in proc.stdout:
yield line + "\n"
continue
while True:
out = proc._output
line = out.get()
if line is None:
break
yield line
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
for i in range(num_cmds):
cmd = CmdState(self.log)
self.procs.append(cmd)
for cmd in self.procs:
yield cmd
# TODO: We need to test concurrency
class TaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BaseTask] = {}
def __getitem__(self, uuid: UUID) -> BaseTask:
with self.lock:
return self.pool[uuid]
def __setitem__(self, uuid: UUID, task: BaseTask) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"Task with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = task
POOL: TaskPool = TaskPool()
def get_task(uuid: UUID) -> BaseTask:
global POOL
return POOL[uuid]
def register_task(task: type, *args: Any) -> UUID:
global POOL
if not issubclass(task, BaseTask):
raise TypeError("task must be a subclass of BaseTask")
uuid = uuid4()
inst_task = task(uuid, *args)
POOL[uuid] = inst_task
inst_task.start()
return uuid