Fixed failing tests
This commit is contained in:
@@ -1,25 +1,25 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any, Iterator
|
||||
from uuid import UUID
|
||||
import threading
|
||||
import queue
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ..dirs import get_clan_flake_toplevel
|
||||
from ..webui.routers import vms
|
||||
from ..webui.schemas import VmConfig
|
||||
from typing import Any, Iterator
|
||||
from fastapi.responses import StreamingResponse
|
||||
import pdb
|
||||
|
||||
|
||||
def read_stream_response(stream: StreamingResponse) -> Iterator[Any]:
|
||||
iterator = stream.body_iterator
|
||||
while True:
|
||||
try:
|
||||
tem = asyncio.run(iterator.__anext__())
|
||||
tem = asyncio.run(iterator.__anext__()) # type: ignore
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
yield tem
|
||||
|
||||
|
||||
def create(args: argparse.Namespace) -> None:
|
||||
clan_dir = get_clan_flake_toplevel().as_posix()
|
||||
vm = VmConfig(
|
||||
@@ -34,13 +34,12 @@ def create(args: argparse.Namespace) -> None:
|
||||
print(res.json())
|
||||
uuid = UUID(res.uuid)
|
||||
|
||||
res = asyncio.run(vms.get_vm_logs(uuid))
|
||||
stream = asyncio.run(vms.get_vm_logs(uuid))
|
||||
|
||||
for line in read_stream_response(res):
|
||||
for line in read_stream_response(stream):
|
||||
print(line)
|
||||
|
||||
|
||||
|
||||
def register_create_parser(parser: argparse.ArgumentParser) -> None:
|
||||
parser.add_argument("machine", type=str)
|
||||
parser.set_defaults(func=create)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Iterator, Iterable
|
||||
from typing import Annotated, Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Body
|
||||
@@ -11,7 +10,7 @@ from fastapi.responses import StreamingResponse
|
||||
|
||||
from ...nix import nix_build, nix_eval, nix_shell
|
||||
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
|
||||
from ..task_manager import BaseTask, get_task, register_task, CmdState
|
||||
from ..task_manager import BaseTask, CmdState, get_task, register_task
|
||||
from .utils import run_cmd
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -39,7 +38,7 @@ class BuildVmTask(BaseTask):
|
||||
super().__init__(uuid)
|
||||
self.vm = vm
|
||||
|
||||
def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict:
|
||||
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
|
||||
clan_dir = self.vm.flake_url
|
||||
machine = self.vm.flake_attr
|
||||
cmd = next(cmds)
|
||||
@@ -71,31 +70,36 @@ class BuildVmTask(BaseTask):
|
||||
disk_img = f"{tmpdir_}/disk.img"
|
||||
|
||||
cmd = next(cmds)
|
||||
cmd.run(nix_shell(
|
||||
["qemu"],
|
||||
cmd.run(
|
||||
nix_shell(
|
||||
["qemu"],
|
||||
[
|
||||
"qemu-img",
|
||||
"create",
|
||||
"-f",
|
||||
"raw",
|
||||
disk_img,
|
||||
"1024M",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
cmd = next(cmds)
|
||||
cmd.run(
|
||||
[
|
||||
"qemu-img",
|
||||
"create",
|
||||
"-f",
|
||||
"raw",
|
||||
"mkfs.ext4",
|
||||
"-L",
|
||||
"nixos",
|
||||
disk_img,
|
||||
"1024M",
|
||||
],
|
||||
))
|
||||
]
|
||||
)
|
||||
|
||||
cmd = next(cmds)
|
||||
cmd.run([
|
||||
"mkfs.ext4",
|
||||
"-L",
|
||||
"nixos",
|
||||
disk_img,
|
||||
])
|
||||
|
||||
cmd = next(cmds)
|
||||
cmd.run(nix_shell(
|
||||
["qemu"],
|
||||
[
|
||||
# fmt: off
|
||||
cmd.run(
|
||||
nix_shell(
|
||||
["qemu"],
|
||||
[
|
||||
# fmt: off
|
||||
"qemu-kvm",
|
||||
"-name", machine,
|
||||
"-m", f'{vm_config["memorySize"]}M',
|
||||
@@ -113,9 +117,10 @@ class BuildVmTask(BaseTask):
|
||||
"-kernel", f'{vm_config["toplevel"]}/kernel',
|
||||
"-initrd", vm_config["initrd"],
|
||||
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
|
||||
# fmt: on
|
||||
],
|
||||
))
|
||||
# fmt: on
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/vms/inspect")
|
||||
@@ -144,8 +149,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
|
||||
def stream_logs() -> Iterator[str]:
|
||||
task = get_task(uuid)
|
||||
|
||||
for line in task.logs_iter():
|
||||
yield line
|
||||
yield from task.logs_iter()
|
||||
|
||||
return StreamingResponse(
|
||||
content=stream_logs(),
|
||||
|
||||
@@ -5,14 +5,14 @@ import select
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Any, Iterable, Iterator
|
||||
from typing import Any, Iterator
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
class CmdState:
|
||||
def __init__(self, log: logging.Logger) -> None:
|
||||
self.log: logging.Logger = log
|
||||
self.p: subprocess.Popen = None
|
||||
self.p: subprocess.Popen | None = None
|
||||
self.stdout: list[str] = []
|
||||
self.stderr: list[str] = []
|
||||
self._output: queue.SimpleQueue = queue.SimpleQueue()
|
||||
@@ -51,7 +51,7 @@ class CmdState:
|
||||
assert self.p.stderr is not None
|
||||
line = self.p.stderr.readline()
|
||||
if line != "":
|
||||
line = line.strip('\n')
|
||||
line = line.strip("\n")
|
||||
self.stderr.append(line)
|
||||
self.log.debug("stderr: %s", line)
|
||||
self._output.put(line)
|
||||
@@ -60,7 +60,7 @@ class CmdState:
|
||||
assert self.p.stdout is not None
|
||||
line = self.p.stdout.readline()
|
||||
if line != "":
|
||||
line = line.strip('\n')
|
||||
line = line.strip("\n")
|
||||
self.stdout.append(line)
|
||||
self.log.debug("stdout: %s", line)
|
||||
self._output.put(line)
|
||||
@@ -93,14 +93,14 @@ class BaseTask(threading.Thread):
|
||||
for proc in self.procs:
|
||||
proc.close_queue()
|
||||
self.failed = True
|
||||
self.finished = True
|
||||
self.log.exception(e)
|
||||
|
||||
finally:
|
||||
self.finished = True
|
||||
|
||||
def task_run(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
## TODO: If two clients are connected to the same task,
|
||||
## TODO: If two clients are connected to the same task,
|
||||
def logs_iter(self) -> Iterator[str]:
|
||||
with self.logs_lock:
|
||||
for proc in self.procs:
|
||||
@@ -120,7 +120,7 @@ class BaseTask(threading.Thread):
|
||||
break
|
||||
yield line
|
||||
|
||||
def register_cmds(self, num_cmds: int) -> Iterable[CmdState]:
|
||||
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
|
||||
for i in range(num_cmds):
|
||||
cmd = CmdState(self.log)
|
||||
self.procs.append(cmd)
|
||||
|
||||
@@ -74,20 +74,11 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
|
||||
assert response.status_code == 200, "Failed to get vm status"
|
||||
|
||||
response = api.get(f"/api/vms/{uuid}/logs")
|
||||
print("=========FLAKE LOGS==========")
|
||||
assert isinstance(response.stream, SyncByteStream)
|
||||
for line in response.stream:
|
||||
assert line != b"", "Failed to get vm logs"
|
||||
print(line.decode("utf-8"), end="")
|
||||
print("=========END LOGS==========")
|
||||
assert response.status_code == 200, "Failed to get vm logs"
|
||||
|
||||
response = api.get(f"/api/vms/{uuid}/logs")
|
||||
assert isinstance(response.stream, SyncByteStream)
|
||||
print("=========VM LOGS==========")
|
||||
assert isinstance(response.stream, SyncByteStream)
|
||||
for line in response.stream:
|
||||
assert line != b"", "Failed to get vm logs"
|
||||
print(line.decode("utf-8"), end="")
|
||||
print(line.decode("utf-8"))
|
||||
print("=========END LOGS==========")
|
||||
assert response.status_code == 200, "Failed to get vm logs"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user