Fixed failing tests

This commit is contained in:
Qubasa
2023-10-03 11:51:31 +02:00
parent 7e180d2f12
commit c1c68ee1d8
4 changed files with 52 additions and 58 deletions

View File

@@ -1,9 +1,8 @@
import json
import logging
import tempfile
import time
from pathlib import Path
from typing import Annotated, Iterator, Iterable
from typing import Annotated, Iterator
from uuid import UUID
from fastapi import APIRouter, Body
@@ -16,7 +15,7 @@ from clan_cli.webui.routers.flake import get_attrs
from ...nix import nix_build, nix_eval
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task, CmdState
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__)
@@ -44,7 +43,7 @@ class BuildVmTask(BaseTask):
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict:
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
@@ -76,31 +75,36 @@ class BuildVmTask(BaseTask):
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"],
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
[
"qemu-img",
"create",
"-f",
"raw",
"mkfs.ext4",
"-L",
"nixos",
disk_img,
"1024M",
],
))
]
)
cmd = next(cmds)
cmd.run([
"mkfs.ext4",
"-L",
"nixos",
disk_img,
])
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"],
[
# fmt: off
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
@@ -118,9 +122,10 @@ class BuildVmTask(BaseTask):
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
))
# fmt: on
],
)
)
@router.post("/api/vms/inspect")
@@ -149,8 +154,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]:
task = get_task(uuid)
for line in task.logs_iter():
yield line
yield from task.logs_iter()
return StreamingResponse(
content=stream_logs(),