Fixed failing tests

This commit is contained in:
Qubasa
2023-10-03 11:51:31 +02:00
parent ab6b96e516
commit 14831e871f
4 changed files with 52 additions and 58 deletions

View File

@@ -1,25 +1,25 @@
import argparse import argparse
import asyncio import asyncio
from typing import Any, Iterator
from uuid import UUID from uuid import UUID
import threading
import queue from fastapi.responses import StreamingResponse
from ..dirs import get_clan_flake_toplevel from ..dirs import get_clan_flake_toplevel
from ..webui.routers import vms from ..webui.routers import vms
from ..webui.schemas import VmConfig from ..webui.schemas import VmConfig
from typing import Any, Iterator
from fastapi.responses import StreamingResponse
import pdb
def read_stream_response(stream: StreamingResponse) -> Iterator[Any]: def read_stream_response(stream: StreamingResponse) -> Iterator[Any]:
iterator = stream.body_iterator iterator = stream.body_iterator
while True: while True:
try: try:
tem = asyncio.run(iterator.__anext__()) tem = asyncio.run(iterator.__anext__()) # type: ignore
except StopAsyncIteration: except StopAsyncIteration:
break break
yield tem yield tem
def create(args: argparse.Namespace) -> None: def create(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix() clan_dir = get_clan_flake_toplevel().as_posix()
vm = VmConfig( vm = VmConfig(
@@ -34,13 +34,12 @@ def create(args: argparse.Namespace) -> None:
print(res.json()) print(res.json())
uuid = UUID(res.uuid) uuid = UUID(res.uuid)
res = asyncio.run(vms.get_vm_logs(uuid)) stream = asyncio.run(vms.get_vm_logs(uuid))
for line in read_stream_response(res): for line in read_stream_response(stream):
print(line) print(line)
def register_create_parser(parser: argparse.ArgumentParser) -> None: def register_create_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument("machine", type=str) parser.add_argument("machine", type=str)
parser.set_defaults(func=create) parser.set_defaults(func=create)

View File

@@ -1,9 +1,8 @@
import json import json
import logging import logging
import tempfile import tempfile
import time
from pathlib import Path from pathlib import Path
from typing import Annotated, Iterator, Iterable from typing import Annotated, Iterator
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Body from fastapi import APIRouter, Body
@@ -11,7 +10,7 @@ from fastapi.responses import StreamingResponse
from ...nix import nix_build, nix_eval, nix_shell from ...nix import nix_build, nix_eval, nix_shell
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task, CmdState from ..task_manager import BaseTask, CmdState, get_task, register_task
from .utils import run_cmd from .utils import run_cmd
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -39,7 +38,7 @@ class BuildVmTask(BaseTask):
super().__init__(uuid) super().__init__(uuid)
self.vm = vm self.vm = vm
def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict: def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url clan_dir = self.vm.flake_url
machine = self.vm.flake_attr machine = self.vm.flake_attr
cmd = next(cmds) cmd = next(cmds)
@@ -71,31 +70,36 @@ class BuildVmTask(BaseTask):
disk_img = f"{tmpdir_}/disk.img" disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds) cmd = next(cmds)
cmd.run(nix_shell( cmd.run(
["qemu"], nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
[ [
"qemu-img", "mkfs.ext4",
"create", "-L",
"-f", "nixos",
"raw",
disk_img, disk_img,
"1024M", ]
], )
))
cmd = next(cmds) cmd = next(cmds)
cmd.run([ cmd.run(
"mkfs.ext4", nix_shell(
"-L", ["qemu"],
"nixos", [
disk_img, # fmt: off
])
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm", "qemu-kvm",
"-name", machine, "-name", machine,
"-m", f'{vm_config["memorySize"]}M', "-m", f'{vm_config["memorySize"]}M',
@@ -113,9 +117,10 @@ class BuildVmTask(BaseTask):
"-kernel", f'{vm_config["toplevel"]}/kernel', "-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"], "-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0', "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on # fmt: on
], ],
)) )
)
@router.post("/api/vms/inspect") @router.post("/api/vms/inspect")
@@ -144,8 +149,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]: def stream_logs() -> Iterator[str]:
task = get_task(uuid) task = get_task(uuid)
for line in task.logs_iter(): yield from task.logs_iter()
yield line
return StreamingResponse( return StreamingResponse(
content=stream_logs(), content=stream_logs(),

View File

@@ -5,14 +5,14 @@ import select
import shlex import shlex
import subprocess import subprocess
import threading import threading
from typing import Any, Iterable, Iterator from typing import Any, Iterator
from uuid import UUID, uuid4 from uuid import UUID, uuid4
class CmdState: class CmdState:
def __init__(self, log: logging.Logger) -> None: def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log self.log: logging.Logger = log
self.p: subprocess.Popen = None self.p: subprocess.Popen | None = None
self.stdout: list[str] = [] self.stdout: list[str] = []
self.stderr: list[str] = [] self.stderr: list[str] = []
self._output: queue.SimpleQueue = queue.SimpleQueue() self._output: queue.SimpleQueue = queue.SimpleQueue()
@@ -51,7 +51,7 @@ class CmdState:
assert self.p.stderr is not None assert self.p.stderr is not None
line = self.p.stderr.readline() line = self.p.stderr.readline()
if line != "": if line != "":
line = line.strip('\n') line = line.strip("\n")
self.stderr.append(line) self.stderr.append(line)
self.log.debug("stderr: %s", line) self.log.debug("stderr: %s", line)
self._output.put(line) self._output.put(line)
@@ -60,7 +60,7 @@ class CmdState:
assert self.p.stdout is not None assert self.p.stdout is not None
line = self.p.stdout.readline() line = self.p.stdout.readline()
if line != "": if line != "":
line = line.strip('\n') line = line.strip("\n")
self.stdout.append(line) self.stdout.append(line)
self.log.debug("stdout: %s", line) self.log.debug("stdout: %s", line)
self._output.put(line) self._output.put(line)
@@ -93,9 +93,9 @@ class BaseTask(threading.Thread):
for proc in self.procs: for proc in self.procs:
proc.close_queue() proc.close_queue()
self.failed = True self.failed = True
self.finished = True
self.log.exception(e) self.log.exception(e)
finally:
self.finished = True
def task_run(self) -> None: def task_run(self) -> None:
raise NotImplementedError raise NotImplementedError
@@ -120,7 +120,7 @@ class BaseTask(threading.Thread):
break break
yield line yield line
def register_cmds(self, num_cmds: int) -> Iterable[CmdState]: def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
for i in range(num_cmds): for i in range(num_cmds):
cmd = CmdState(self.log) cmd = CmdState(self.log)
self.procs.append(cmd) self.procs.append(cmd)

View File

@@ -74,20 +74,11 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
assert response.status_code == 200, "Failed to get vm status" assert response.status_code == 200, "Failed to get vm status"
response = api.get(f"/api/vms/{uuid}/logs") response = api.get(f"/api/vms/{uuid}/logs")
print("=========FLAKE LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"), end="")
print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs"
response = api.get(f"/api/vms/{uuid}/logs")
assert isinstance(response.stream, SyncByteStream)
print("=========VM LOGS==========") print("=========VM LOGS==========")
assert isinstance(response.stream, SyncByteStream)
for line in response.stream: for line in response.stream:
assert line != b"", "Failed to get vm logs" assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8"), end="") print(line.decode("utf-8"))
print("=========END LOGS==========") print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs" assert response.status_code == 200, "Failed to get vm logs"