From c1c68ee1d87b6ed18a7a888bce5e222b6e49c1d5 Mon Sep 17 00:00:00 2001 From: Qubasa Date: Tue, 3 Oct 2023 11:51:31 +0200 Subject: [PATCH] Fixed failing tests --- pkgs/clan-cli/clan_cli/vms/create.py | 17 +++--- pkgs/clan-cli/clan_cli/webui/routers/vms.py | 64 +++++++++++--------- pkgs/clan-cli/clan_cli/webui/task_manager.py | 16 ++--- pkgs/clan-cli/tests/test_vms_api.py | 13 +--- 4 files changed, 52 insertions(+), 58 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/vms/create.py b/pkgs/clan-cli/clan_cli/vms/create.py index d2481326d..78f441d55 100644 --- a/pkgs/clan-cli/clan_cli/vms/create.py +++ b/pkgs/clan-cli/clan_cli/vms/create.py @@ -1,25 +1,25 @@ import argparse import asyncio +from typing import Any, Iterator from uuid import UUID -import threading -import queue + +from fastapi.responses import StreamingResponse from ..dirs import get_clan_flake_toplevel from ..webui.routers import vms from ..webui.schemas import VmConfig -from typing import Any, Iterator -from fastapi.responses import StreamingResponse -import pdb + def read_stream_response(stream: StreamingResponse) -> Iterator[Any]: iterator = stream.body_iterator while True: try: - tem = asyncio.run(iterator.__anext__()) + tem = asyncio.run(iterator.__anext__()) # type: ignore except StopAsyncIteration: break yield tem + def create(args: argparse.Namespace) -> None: clan_dir = get_clan_flake_toplevel().as_posix() vm = VmConfig( @@ -34,13 +34,12 @@ def create(args: argparse.Namespace) -> None: print(res.json()) uuid = UUID(res.uuid) - res = asyncio.run(vms.get_vm_logs(uuid)) + stream = asyncio.run(vms.get_vm_logs(uuid)) - for line in read_stream_response(res): + for line in read_stream_response(stream): print(line) - def register_create_parser(parser: argparse.ArgumentParser) -> None: parser.add_argument("machine", type=str) parser.set_defaults(func=create) diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index a1de7cb1a..9b581bda5 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -1,9 +1,8 @@ import json import logging import tempfile -import time from pathlib import Path -from typing import Annotated, Iterator, Iterable +from typing import Annotated, Iterator from uuid import UUID from fastapi import APIRouter, Body @@ -16,7 +15,7 @@ from clan_cli.webui.routers.flake import get_attrs from ...nix import nix_build, nix_eval from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse -from ..task_manager import BaseTask, get_task, register_task, CmdState +from ..task_manager import BaseTask, CmdState, get_task, register_task from .utils import run_cmd log = logging.getLogger(__name__) @@ -44,7 +43,7 @@ class BuildVmTask(BaseTask): super().__init__(uuid) self.vm = vm - def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict: + def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict: clan_dir = self.vm.flake_url machine = self.vm.flake_attr cmd = next(cmds) @@ -76,31 +75,36 @@ class BuildVmTask(BaseTask): disk_img = f"{tmpdir_}/disk.img" cmd = next(cmds) - cmd.run(nix_shell( - ["qemu"], + cmd.run( + nix_shell( + ["qemu"], + [ + "qemu-img", + "create", + "-f", + "raw", + disk_img, + "1024M", + ], + ) + ) + + cmd = next(cmds) + cmd.run( [ - "qemu-img", - "create", - "-f", - "raw", + "mkfs.ext4", + "-L", + "nixos", disk_img, - "1024M", - ], - )) + ] + ) cmd = next(cmds) - cmd.run([ - "mkfs.ext4", - "-L", - "nixos", - disk_img, - ]) - - cmd = next(cmds) - cmd.run(nix_shell( - ["qemu"], - [ - # fmt: off + cmd.run( + nix_shell( + ["qemu"], + [ + # fmt: off "qemu-kvm", "-name", machine, "-m", f'{vm_config["memorySize"]}M', @@ -118,9 +122,10 @@ class BuildVmTask(BaseTask): "-kernel", f'{vm_config["toplevel"]}/kernel', "-initrd", vm_config["initrd"], "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0', - # fmt: on - ], - )) + # fmt: on + ], + ) + ) @router.post("/api/vms/inspect") @@ -149,8 +154,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse: def stream_logs() -> Iterator[str]: task = get_task(uuid) - for line in task.logs_iter(): - yield line + yield from task.logs_iter() return StreamingResponse( content=stream_logs(), diff --git a/pkgs/clan-cli/clan_cli/webui/task_manager.py b/pkgs/clan-cli/clan_cli/webui/task_manager.py index 7e15930d2..c0913e60f 100644 --- a/pkgs/clan-cli/clan_cli/webui/task_manager.py +++ b/pkgs/clan-cli/clan_cli/webui/task_manager.py @@ -5,14 +5,14 @@ import select import shlex import subprocess import threading -from typing import Any, Iterable, Iterator +from typing import Any, Iterator from uuid import UUID, uuid4 class CmdState: def __init__(self, log: logging.Logger) -> None: self.log: logging.Logger = log - self.p: subprocess.Popen = None + self.p: subprocess.Popen | None = None self.stdout: list[str] = [] self.stderr: list[str] = [] self._output: queue.SimpleQueue = queue.SimpleQueue() @@ -51,7 +51,7 @@ class CmdState: assert self.p.stderr is not None line = self.p.stderr.readline() if line != "": - line = line.strip('\n') + line = line.strip("\n") self.stderr.append(line) self.log.debug("stderr: %s", line) self._output.put(line) @@ -60,7 +60,7 @@ class CmdState: assert self.p.stdout is not None line = self.p.stdout.readline() if line != "": - line = line.strip('\n') + line = line.strip("\n") self.stdout.append(line) self.log.debug("stdout: %s", line) self._output.put(line) @@ -93,14 +93,14 @@ class BaseTask(threading.Thread): for proc in self.procs: proc.close_queue() self.failed = True - self.finished = True self.log.exception(e) - + finally: + self.finished = True def task_run(self) -> None: raise NotImplementedError - ## TODO: If two clients are connected to the same task, + ## TODO: If two clients are connected to the same task, def logs_iter(self) -> Iterator[str]: with self.logs_lock: for proc in self.procs: @@ -120,7 +120,7 @@ class BaseTask(threading.Thread): break yield line - def register_cmds(self, num_cmds: int) -> Iterable[CmdState]: + def register_cmds(self, num_cmds: int) -> Iterator[CmdState]: for i in range(num_cmds): cmd = CmdState(self.log) self.procs.append(cmd) diff --git a/pkgs/clan-cli/tests/test_vms_api.py b/pkgs/clan-cli/tests/test_vms_api.py index 5bbc3c6d8..7904af19e 100644 --- a/pkgs/clan-cli/tests/test_vms_api.py +++ b/pkgs/clan-cli/tests/test_vms_api.py @@ -74,20 +74,11 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None: assert response.status_code == 200, "Failed to get vm status" response = api.get(f"/api/vms/{uuid}/logs") - print("=========FLAKE LOGS==========") - assert isinstance(response.stream, SyncByteStream) - for line in response.stream: - assert line != b"", "Failed to get vm logs" - print(line.decode("utf-8"), end="") - print("=========END LOGS==========") - assert response.status_code == 200, "Failed to get vm logs" - - response = api.get(f"/api/vms/{uuid}/logs") - assert isinstance(response.stream, SyncByteStream) print("=========VM LOGS==========") + assert isinstance(response.stream, SyncByteStream) for line in response.stream: assert line != b"", "Failed to get vm logs" - print(line.decode("utf-8"), end="") + print(line.decode("utf-8")) print("=========END LOGS==========") assert response.status_code == 200, "Failed to get vm logs"