Fixed failing tests

This commit is contained in:
Qubasa
2023-10-03 11:51:31 +02:00
parent 3a8ce96b43
commit c78af6243c
4 changed files with 52 additions and 58 deletions

View File

@@ -1,9 +1,8 @@
import json
import logging
import tempfile
import time
from pathlib import Path
from typing import Annotated, Iterator, Iterable
from typing import Annotated, Iterator
from uuid import UUID
from fastapi import APIRouter, Body
@@ -16,7 +15,7 @@ from clan_cli.webui.routers.flake import get_attrs
from ...nix import nix_build, nix_eval
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task, CmdState
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .utils import run_cmd
log = logging.getLogger(__name__)
@@ -44,7 +43,7 @@ class BuildVmTask(BaseTask):
super().__init__(uuid)
self.vm = vm
def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict:
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
clan_dir = self.vm.flake_url
machine = self.vm.flake_attr
cmd = next(cmds)
@@ -76,31 +75,36 @@ class BuildVmTask(BaseTask):
disk_img = f"{tmpdir_}/disk.img"
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"],
cmd.run(
nix_shell(
["qemu"],
[
"qemu-img",
"create",
"-f",
"raw",
disk_img,
"1024M",
],
)
)
cmd = next(cmds)
cmd.run(
[
"qemu-img",
"create",
"-f",
"raw",
"mkfs.ext4",
"-L",
"nixos",
disk_img,
"1024M",
],
))
]
)
cmd = next(cmds)
cmd.run([
"mkfs.ext4",
"-L",
"nixos",
disk_img,
])
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"],
[
# fmt: off
cmd.run(
nix_shell(
["qemu"],
[
# fmt: off
"qemu-kvm",
"-name", machine,
"-m", f'{vm_config["memorySize"]}M',
@@ -118,9 +122,10 @@ class BuildVmTask(BaseTask):
"-kernel", f'{vm_config["toplevel"]}/kernel',
"-initrd", vm_config["initrd"],
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on
],
))
# fmt: on
],
)
)
@router.post("/api/vms/inspect")
@@ -149,8 +154,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]:
task = get_task(uuid)
for line in task.logs_iter():
yield line
yield from task.logs_iter()
return StreamingResponse(
content=stream_logs(),

View File

@@ -5,14 +5,14 @@ import select
import shlex
import subprocess
import threading
from typing import Any, Iterable, Iterator
from typing import Any, Iterator
from uuid import UUID, uuid4
class CmdState:
def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log
self.p: subprocess.Popen = None
self.p: subprocess.Popen | None = None
self.stdout: list[str] = []
self.stderr: list[str] = []
self._output: queue.SimpleQueue = queue.SimpleQueue()
@@ -51,7 +51,7 @@ class CmdState:
assert self.p.stderr is not None
line = self.p.stderr.readline()
if line != "":
line = line.strip('\n')
line = line.strip("\n")
self.stderr.append(line)
self.log.debug("stderr: %s", line)
self._output.put(line)
@@ -60,7 +60,7 @@ class CmdState:
assert self.p.stdout is not None
line = self.p.stdout.readline()
if line != "":
line = line.strip('\n')
line = line.strip("\n")
self.stdout.append(line)
self.log.debug("stdout: %s", line)
self._output.put(line)
@@ -93,14 +93,14 @@ class BaseTask(threading.Thread):
for proc in self.procs:
proc.close_queue()
self.failed = True
self.finished = True
self.log.exception(e)
finally:
self.finished = True
def task_run(self) -> None:
raise NotImplementedError
## TODO: If two clients are connected to the same task,
## TODO: If two clients are connected to the same task,
def logs_iter(self) -> Iterator[str]:
with self.logs_lock:
for proc in self.procs:
@@ -120,7 +120,7 @@ class BaseTask(threading.Thread):
break
yield line
def register_cmds(self, num_cmds: int) -> Iterable[CmdState]:
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
for i in range(num_cmds):
cmd = CmdState(self.log)
self.procs.append(cmd)