API: Added endpoint & test for /api/flake/attrs

This commit is contained in:
Qubasa
2023-10-01 12:45:01 +02:00
parent 84bcfc3929
commit ce7ae81a35
7 changed files with 110 additions and 68 deletions

View File

@@ -11,6 +11,18 @@ def nix_command(flags: list[str]) -> list[str]:
return ["nix", "--extra-experimental-features", "nix-command flakes"] + flags return ["nix", "--extra-experimental-features", "nix-command flakes"] + flags
def nix_flake_show(flake_url: str) -> list[str]:
return nix_command(
[
"flake",
"show",
"--json",
"--show-trace",
f"{flake_url}",
]
)
def nix_build( def nix_build(
flags: list[str], flags: list[str],
) -> list[str]: ) -> list[str]:

View File

@@ -7,7 +7,7 @@ from fastapi.staticfiles import StaticFiles
from .. import custom_logger from .. import custom_logger
from .assets import asset_path from .assets import asset_path
from .routers import flake, health, machines, root, vms from .routers import flake, health, machines, root, utils, vms
origins = [ origins = [
"http://localhost:3000", "http://localhost:3000",
@@ -33,7 +33,9 @@ def setup_app() -> FastAPI:
# Needs to be last in register. Because of wildcard route # Needs to be last in register. Because of wildcard route
app.include_router(root.router) app.include_router(root.router)
app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler) app.add_exception_handler(
utils.NixBuildException, utils.nix_build_exception_handler
)
app.mount("/static", StaticFiles(directory=asset_path()), name="static") app.mount("/static", StaticFiles(directory=asset_path()), name="static")

View File

@@ -1,16 +1,26 @@
import asyncio
import json import json
from pathlib import Path from pathlib import Path
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException
from clan_cli.webui.schemas import FlakeAction, FlakeResponse from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse
from ...nix import nix_command from ...nix import nix_command, nix_flake_show
from .utils import run_cmd
router = APIRouter() router = APIRouter()
@router.get("/api/flake/attrs")
async def inspect_flake_attrs(url: str) -> FlakeAttrResponse:
cmd = nix_flake_show(url)
stdout = await run_cmd(cmd)
data = json.loads(stdout)
nixos_configs = data["nixosConfigurations"]
flake_attrs = list(nixos_configs.keys())
return FlakeAttrResponse(flake_attrs=flake_attrs)
@router.get("/api/flake") @router.get("/api/flake")
async def inspect_flake( async def inspect_flake(
url: str, url: str,
@@ -19,17 +29,7 @@ async def inspect_flake(
# Extract the flake from the given URL # Extract the flake from the given URL
# We do this by running 'nix flake prefetch {url} --json' # We do this by running 'nix flake prefetch {url} --json'
cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"]) cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"])
proc = await asyncio.create_subprocess_exec( stdout = await run_cmd(cmd)
cmd[0],
*cmd[1:],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(stderr))
data: dict[str, str] = json.loads(stdout) data: dict[str, str] = json.loads(stdout)
if data.get("storePath") is None: if data.get("storePath") is None:

View File

@@ -0,0 +1,54 @@
import asyncio
import logging
import shlex
from fastapi import HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
log = logging.getLogger(__name__)
class NixBuildException(HTTPException):
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
detail = [
{
"loc": loc,
"msg": msg,
"type": "value_error",
}
]
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
)
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
async def run_cmd(cmd: list[str]) -> bytes:
log.debug(f"Running command: {shlex.join(cmd)}")
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
return stdout

View File

@@ -1,17 +1,15 @@
import asyncio
import json import json
import logging import logging
import shlex
from typing import Annotated, Iterator from typing import Annotated, Iterator
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request, status from fastapi import APIRouter, BackgroundTasks, Body
from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from ...nix import nix_build, nix_eval from ...nix import nix_build, nix_eval
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task from ..task_manager import BaseTask, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@@ -33,20 +31,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
) )
class NixBuildException(HTTPException):
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
detail = [
{
"loc": loc,
"msg": msg,
"type": "value_error",
}
]
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
)
class BuildVmTask(BaseTask): class BuildVmTask(BaseTask):
def __init__(self, uuid: UUID, vm: VmConfig) -> None: def __init__(self, uuid: UUID, vm: VmConfig) -> None:
super().__init__(uuid) super().__init__(uuid)
@@ -71,43 +55,12 @@ class BuildVmTask(BaseTask):
log.exception(e) log.exception(e)
def nix_build_exception_handler(
request: Request, exc: NixBuildException
) -> JSONResponse:
log.error("NixBuildException: %s", exc)
return JSONResponse(
status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)),
)
##################################
# #
# ======== VM ROUTES ======== #
# #
##################################
@router.post("/api/vms/inspect") @router.post("/api/vms/inspect")
async def inspect_vm( async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()] flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse: ) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url) cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
proc = await asyncio.create_subprocess_exec( stdout = await run_cmd(cmd)
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
data = json.loads(stdout) data = json.loads(stdout)
return VmInspectResponse( return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data) config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)

View File

@@ -53,6 +53,10 @@ class VmCreateResponse(BaseModel):
uuid: str uuid: str
class FlakeAttrResponse(BaseModel):
flake_attrs: list[str]
class VmInspectResponse(BaseModel): class VmInspectResponse(BaseModel):
config: VmConfig config: VmConfig

View File

@@ -0,0 +1,17 @@
from pathlib import Path
import pytest
from api import TestClient
@pytest.mark.impure
def test_inspect(api: TestClient, test_flake_with_core: Path) -> None:
params = {"url": str(test_flake_with_core)}
response = api.get(
"/api/flake/attrs",
params=params,
)
assert response.status_code == 200, "Failed to inspect vm"
data = response.json()
print("Data: ", data)
assert data.get("flake_attrs") == ["vm1"]