move out vm logic out of controller

This commit is contained in:
Jörg Thalheim
2023-10-03 16:47:14 +02:00
parent 6de1aeebb9
commit c03effed54
10 changed files with 203 additions and 218 deletions

View File

@@ -0,0 +1,28 @@
import asyncio
import logging
import shlex
from .errors import ClanError
log = logging.getLogger(__name__)
async def run(cmd: list[str]) -> bytes:
log.debug(f"$: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise ClanError(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@@ -1,26 +1,114 @@
import argparse import argparse
import asyncio import json
from typing import Any, Iterator import tempfile
from pathlib import Path
from typing import Iterator
from uuid import UUID from uuid import UUID
from fastapi.responses import StreamingResponse
from ..dirs import get_clan_flake_toplevel from ..dirs import get_clan_flake_toplevel
from ..webui.routers import vms from ..nix import nix_build, nix_shell
from ..webui.schemas import VmConfig from ..task_manager import BaseTask, CmdState, get_task, register_task
from .inspect import VmConfig
def read_stream_response(stream: StreamingResponse) -> Iterator[Any]: class BuildVmTask(BaseTask):
iterator = stream.body_iterator def __init__(self, uuid: UUID, vm: VmConfig) -> None:
while True: super().__init__(uuid)
try: self.vm = vm
tem = asyncio.run(iterator.__anext__()) # type: ignore
except StopAsyncIteration: def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
break clan_dir = self.vm.flake_url
yield tem machine = self.vm.flake_attr
cmd = next(cmds)
cmd.run(
nix_build(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
]
)
)
vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
return json.load(f)
def task_run(self) -> None:
cmds = self.register_cmds(4)
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
vm_config = self.get_vm_create_info(cmds)
with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
nix_shell(
["e2fsprogs"],
[
"mkfs.ext4",
"-L",
"nixos",
disk_img,
],
)
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
"-smp", str(vm_config["cores"]),
"-device", "virtio-rng-pci",
"-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
"-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
"-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
"-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
"-device", "virtio-keyboard",
"-usb",
"-device", "usb-tablet,bus=usb-bus.0",
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
)
)
def create(args: argparse.Namespace) -> None: def create_vm(vm: VmConfig) -> UUID:
return register_task(BuildVmTask, vm)
def create_command(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix() clan_dir = get_clan_flake_toplevel().as_posix()
vm = VmConfig( vm = VmConfig(
flake_url=clan_dir, flake_url=clan_dir,
@@ -30,17 +118,12 @@ def create(args: argparse.Namespace) -> None:
memory_size=0, memory_size=0,
) )
res = asyncio.run(vms.create_vm(vm)) uuid = create_vm(vm)
print(res.json()) task = get_task(uuid)
uuid = UUID(res.uuid) for line in task.logs_iter():
stream = asyncio.run(vms.get_vm_logs(uuid))
for line in read_stream_response(stream):
print(line, end="") print(line, end="")
print("")
def register_create_parser(parser: argparse.ArgumentParser) -> None: def register_create_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str) parser.add_argument("machine", type=str)
parser.set_defaults(func=create) parser.set_defaults(func=create_command)

View File

@@ -1,16 +1,42 @@
import argparse import argparse
import asyncio import asyncio
import json
from pydantic import BaseModel
from ..async_cmd import run
from ..dirs import get_clan_flake_toplevel from ..dirs import get_clan_flake_toplevel
from ..webui.routers import vms from ..nix import nix_eval
def inspect(args: argparse.Namespace) -> None: class VmConfig(BaseModel):
flake_url: str
flake_attr: str
cores: int
memory_size: int
graphics: bool
async def inspect_vm(flake_url: str, flake_attr: str) -> VmConfig:
cmd = nix_eval(
[
f"{flake_url}#nixosConfigurations.{json.dumps(flake_attr)}.config.system.clan.vm.config"
]
)
stdout = await run(cmd)
data = json.loads(stdout)
return VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
def inspect_command(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix() clan_dir = get_clan_flake_toplevel().as_posix()
res = asyncio.run(vms.inspect_vm(flake_url=clan_dir, flake_attr=args.machine)) res = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
print(res.json()) print("Cores:", res.cores)
print("Memory size:", res.memory_size)
print("Graphics:", res.graphics)
def register_inspect_parser(parser: argparse.ArgumentParser) -> None: def register_inspect_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str) parser.add_argument("machine", type=str)
parser.set_defaults(func=inspect) parser.set_defaults(func=inspect_command)

View File

@@ -5,8 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from ..errors import ClanError
from .assets import asset_path from .assets import asset_path
from .routers import flake, health, machines, root, utils, vms from .error_handlers import clan_error_handler
from .routers import flake, health, machines, root, vms
origins = [ origins = [
"http://localhost:3000", "http://localhost:3000",
@@ -32,9 +34,7 @@ def setup_app() -> FastAPI:
# Needs to be last in register. Because of wildcard route # Needs to be last in register. Because of wildcard route
app.include_router(root.router) app.include_router(root.router)
app.add_exception_handler( app.add_exception_handler(ClanError, clan_error_handler)
utils.NixBuildException, utils.nix_build_exception_handler
)
app.mount("/static", StaticFiles(directory=asset_path()), name="static") app.mount("/static", StaticFiles(directory=asset_path()), name="static")

View File

@@ -0,0 +1,23 @@
import logging
from fastapi import Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from ..errors import ClanError
log = logging.getLogger(__name__)
def clan_error_handler(request: Request, exc: ClanError) -> JSONResponse:
log.error("ClanError: %s", exc)
detail = [
{
"loc": [],
"msg": str(exc),
}
]
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder(dict(detail=detail)),
)

View File

@@ -6,15 +6,15 @@ from fastapi import APIRouter, HTTPException
from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse
from ...async_cmd import run
from ...nix import nix_command, nix_flake_show from ...nix import nix_command, nix_flake_show
from .utils import run_cmd
router = APIRouter() router = APIRouter()
async def get_attrs(url: str) -> list[str]: async def get_attrs(url: str) -> list[str]:
cmd = nix_flake_show(url) cmd = nix_flake_show(url)
stdout = await run_cmd(cmd) stdout = await run(cmd)
data: dict[str, dict] = {} data: dict[str, dict] = {}
try: try:
@@ -45,7 +45,7 @@ async def inspect_flake(
# Extract the flake from the given URL # Extract the flake from the given URL
# We do this by running 'nix flake prefetch {url} --json' # We do this by running 'nix flake prefetch {url} --json'
cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"]) cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"])
stdout = await run_cmd(cmd) stdout = await run(cmd)
data: dict[str, str] = json.loads(stdout) data: dict[str, str] = json.loads(stdout)
if data.get("storePath") is None: if data.get("storePath") is None:

View File

@@ -1,54 +0,0 @@
import asyncio
import logging
import shlex
from fastapi import HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
log = logging.getLogger(__name__)
class NixBuildException(HTTPException):
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
detail = [
{
"loc": loc,
"msg": msg,
"type": "value_error",
}
]
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
)
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
async def run_cmd(cmd: list[str]) -> bytes:
log.debug(f"Running command: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@@ -1,7 +1,4 @@
import json
import logging import logging
import tempfile
from pathlib import Path
from typing import Annotated, Iterator from typing import Annotated, Iterator
from uuid import UUID from uuid import UUID
@@ -11,131 +8,20 @@ from fastapi.responses import StreamingResponse
from clan_cli.webui.routers.flake import get_attrs from clan_cli.webui.routers.flake import get_attrs
from ...nix import nix_build, nix_eval, nix_shell from ...task_manager import get_task
from ...vms import create, inspect
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_eval(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.clan.vm.config"
]
)
def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
return nix_build(
[
f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.build.vm"
]
)
class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
cmd.run(
nix_build(
[
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
]
)
)
vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f:
return json.load(f)
def task_run(self) -> None:
cmds = self.register_cmds(4)
machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}")
# TODO: We should get this from the vm argument
vm_config = self.get_vm_create_info(cmds)
with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
[
"mkfs.ext4",
"-L",
"nixos",
disk_img,
]
)
cmd = next(cmds)
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
"-smp", str(vm_config["cores"]),
"-device", "virtio-rng-pci",
"-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
"-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
"-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
"-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
"-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
"-device", "virtio-keyboard",
"-usb",
"-device", "usb-tablet,bus=usb-bus.0",
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
)
)
@router.post("/api/vms/inspect") @router.post("/api/vms/inspect")
async def inspect_vm( async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()] flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse: ) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url) config = await inspect.inspect_vm(flake_url, flake_attr)
stdout = await run_cmd(cmd) return VmInspectResponse(config=config)
data = json.loads(stdout)
return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
)
@router.get("/api/vms/{uuid}/status") @router.get("/api/vms/{uuid}/status")
@@ -168,5 +54,5 @@ async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Provided attribute '{vm.flake_attr}' does not exist.", detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
) )
uuid = register_task(BuildVmTask, vm) uuid = create.create_vm(vm)
return VmCreateResponse(uuid=str(uuid)) return VmCreateResponse(uuid=str(uuid))

View File

@@ -3,6 +3,8 @@ from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..vms.inspect import VmConfig
class Status(Enum): class Status(Enum):
ONLINE = "online" ONLINE = "online"
@@ -35,15 +37,6 @@ class SchemaResponse(BaseModel):
schema_: dict = Field(alias="schema") schema_: dict = Field(alias="schema")
class VmConfig(BaseModel):
flake_url: str
flake_attr: str
cores: int
memory_size: int
graphics: bool
class VmStatusResponse(BaseModel): class VmStatusResponse(BaseModel):
returncode: list[int | None] returncode: list[int | None]
running: bool running: bool