From 3a8ce96b438748209db4a40a11cd833d2974caba Mon Sep 17 00:00:00 2001 From: Qubasa Date: Mon, 2 Oct 2023 18:36:50 +0200 Subject: [PATCH] CLI: Restructured TaskManager and log collection --- pkgs/clan-cli/clan_cli/vms/create.py | 21 +++ pkgs/clan-cli/clan_cli/webui/routers/vms.py | 75 ++++------ pkgs/clan-cli/clan_cli/webui/task_manager.py | 137 ++++++++++++------- 3 files changed, 136 insertions(+), 97 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/vms/create.py b/pkgs/clan-cli/clan_cli/vms/create.py index a01c31640..d2481326d 100644 --- a/pkgs/clan-cli/clan_cli/vms/create.py +++ b/pkgs/clan-cli/clan_cli/vms/create.py @@ -1,10 +1,24 @@ import argparse import asyncio +from uuid import UUID +import threading +import queue from ..dirs import get_clan_flake_toplevel from ..webui.routers import vms from ..webui.schemas import VmConfig +from typing import Any, Iterator +from fastapi.responses import StreamingResponse +import pdb +def read_stream_response(stream: StreamingResponse) -> Iterator[Any]: + iterator = stream.body_iterator + while True: + try: + tem = asyncio.run(iterator.__anext__()) + except StopAsyncIteration: + break + yield tem def create(args: argparse.Namespace) -> None: clan_dir = get_clan_flake_toplevel().as_posix() @@ -18,6 +32,13 @@ def create(args: argparse.Namespace) -> None: res = asyncio.run(vms.create_vm(vm)) print(res.json()) + uuid = UUID(res.uuid) + + res = asyncio.run(vms.get_vm_logs(uuid)) + + for line in read_stream_response(res): + print(line) + def register_create_parser(parser: argparse.ArgumentParser) -> None: diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index a27344c04..a1de7cb1a 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -1,8 +1,9 @@ import json import logging import tempfile +import time from pathlib import Path -from typing import Annotated, Iterator +from typing import Annotated, Iterator, Iterable from uuid import UUID from fastapi import APIRouter, Body @@ -15,7 +16,7 @@ from clan_cli.webui.routers.flake import get_attrs from ...nix import nix_build, nix_eval from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse -from ..task_manager import BaseTask, get_task, register_task +from ..task_manager import BaseTask, get_task, register_task, CmdState from .utils import run_cmd log = logging.getLogger(__name__) @@ -43,10 +44,11 @@ class BuildVmTask(BaseTask): super().__init__(uuid) self.vm = vm - def get_vm_create_info(self) -> dict: + def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict: clan_dir = self.vm.flake_url machine = self.vm.flake_attr - cmd_state = self.run_cmd( + cmd = next(cmds) + cmd.run( nix_build( [ # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this @@ -54,41 +56,48 @@ class BuildVmTask(BaseTask): ] ) ) - vm_json = "".join(cmd_state.stdout) + vm_json = "".join(cmd.stdout) self.log.debug(f"VM JSON path: {vm_json}") with open(vm_json) as f: return json.load(f) def task_run(self) -> None: + cmds = self.register_cmds(4) + machine = self.vm.flake_attr self.log.debug(f"Creating VM for {machine}") - vm_config = self.get_vm_create_info() + + # TODO: We should get this from the vm argument + vm_config = self.get_vm_create_info(cmds) + with tempfile.TemporaryDirectory() as tmpdir_: xchg_dir = Path(tmpdir_) / "xchg" xchg_dir.mkdir() disk_img = f"{tmpdir_}/disk.img" - cmd = nix_shell( + + cmd = next(cmds) + cmd.run(nix_shell( ["qemu"], [ - "qemu" "qemu-img", + "qemu-img", "create", "-f", "raw", disk_img, "1024M", ], - ) - self.run_cmd(cmd) + )) - cmd = [ + cmd = next(cmds) + cmd.run([ "mkfs.ext4", "-L", "nixos", disk_img, - ] - self.run_cmd(cmd) + ]) - cmd = nix_shell( + cmd = next(cmds) + cmd.run(nix_shell( ["qemu"], [ # fmt: off @@ -111,26 +120,7 @@ class BuildVmTask(BaseTask): "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0', # fmt: on ], - ) - self.run_cmd(cmd) - - # def run(self) -> None: - # try: - # self.log.debug(f"BuildVM with uuid {self.uuid} started") - # cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url) - - # proc = self.run_cmd(cmd) - # self.log.debug(f"stdout: {proc.stdout}") - - # vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm" - # self.log.debug(f"vm_path: {vm_path}") - - # self.run_cmd([vm_path]) - # self.finished = True - # except Exception as e: - # self.failed = True - # self.finished = True - # log.exception(e) + )) @router.post("/api/vms/inspect") @@ -159,21 +149,8 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse: def stream_logs() -> Iterator[str]: task = get_task(uuid) - for proc in task.procs: - if proc.done: - log.debug("stream logs and proc is done") - for line in proc.stderr: - yield line + "\n" - for line in proc.stdout: - yield line + "\n" - continue - while True: - out = proc.output - line = out.get() - if line is None: - log.debug("stream logs and line is None") - break - yield line + for line in task.logs_iter(): + yield line return StreamingResponse( content=stream_logs(), diff --git a/pkgs/clan-cli/clan_cli/webui/task_manager.py b/pkgs/clan-cli/clan_cli/webui/task_manager.py index 58a5995a4..7e15930d2 100644 --- a/pkgs/clan-cli/clan_cli/webui/task_manager.py +++ b/pkgs/clan-cli/clan_cli/webui/task_manager.py @@ -5,19 +5,72 @@ import select import shlex import subprocess import threading -from typing import Any +from typing import Any, Iterable, Iterator from uuid import UUID, uuid4 class CmdState: - def __init__(self, proc: subprocess.Popen) -> None: - global LOOP - self.proc: subprocess.Popen = proc + def __init__(self, log: logging.Logger) -> None: + self.log: logging.Logger = log + self.p: subprocess.Popen = None self.stdout: list[str] = [] self.stderr: list[str] = [] - self.output: queue.SimpleQueue = queue.SimpleQueue() + self._output: queue.SimpleQueue = queue.SimpleQueue() self.returncode: int | None = None self.done: bool = False + self.running: bool = False + self.cmd_str: str | None = None + self.workdir: str | None = None + + def close_queue(self) -> None: + if self.p is not None: + self.returncode = self.p.returncode + self._output.put(None) + self.running = False + self.done = True + + def run(self, cmd: list[str]) -> None: + self.running = True + try: + self.cmd_str = shlex.join(cmd) + self.workdir = os.getcwd() + self.log.debug(f"Working directory: {self.workdir}") + self.log.debug(f"Running command: {shlex.join(cmd)}") + self.p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + cwd=self.workdir, + ) + + while self.p.poll() is None: + # Check if stderr is ready to be read from + rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0) + if self.p.stderr in rlist: + assert self.p.stderr is not None + line = self.p.stderr.readline() + if line != "": + line = line.strip('\n') + self.stderr.append(line) + self.log.debug("stderr: %s", line) + self._output.put(line) + + if self.p.stdout in rlist: + assert self.p.stdout is not None + line = self.p.stdout.readline() + if line != "": + line = line.strip('\n') + self.stdout.append(line) + self.log.debug("stdout: %s", line) + self._output.put(line) + + if self.p.returncode != 0: + raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}") + + self.log.debug("Successfully ran command") + finally: + self.close_queue() class BaseTask(threading.Thread): @@ -31,64 +84,52 @@ class BaseTask(threading.Thread): self.procs: list[CmdState] = [] self.failed: bool = False self.finished: bool = False + self.logs_lock = threading.Lock() def run(self) -> None: try: self.task_run() except Exception as e: + for proc in self.procs: + proc.close_queue() self.failed = True - self.log.exception(e) - finally: self.finished = True + self.log.exception(e) + def task_run(self) -> None: raise NotImplementedError - def run_cmd(self, cmd: list[str]) -> CmdState: - cwd = os.getcwd() - self.log.debug(f"Working directory: {cwd}") - self.log.debug(f"Running command: {shlex.join(cmd)}") - p = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding="utf-8", - # shell=True, - cwd=cwd, - ) - self.procs.append(CmdState(p)) - p_state = self.procs[-1] + ## TODO: If two clients are connected to the same task, + def logs_iter(self) -> Iterator[str]: + with self.logs_lock: + for proc in self.procs: + if self.finished: + self.log.debug("log iter: Task is finished") + break + if proc.done: + for line in proc.stderr: + yield line + for line in proc.stdout: + yield line + continue + while True: + out = proc._output + line = out.get() + if line is None: + break + yield line - while p.poll() is None: - # Check if stderr is ready to be read from - rlist, _, _ = select.select([p.stderr, p.stdout], [], [], 0) - if p.stderr in rlist: - assert p.stderr is not None - line = p.stderr.readline() - if line != "": - p_state.stderr.append(line.strip("\n")) - self.log.debug(f"stderr: {line}") - p_state.output.put(line) + def register_cmds(self, num_cmds: int) -> Iterable[CmdState]: + for i in range(num_cmds): + cmd = CmdState(self.log) + self.procs.append(cmd) - if p.stdout in rlist: - assert p.stdout is not None - line = p.stdout.readline() - if line != "": - p_state.stdout.append(line.strip("\n")) - self.log.debug(f"stdout: {line}") - p_state.output.put(line) - - p_state.returncode = p.returncode - p_state.output.put(None) - p_state.done = True - - if p.returncode != 0: - raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}") - - self.log.debug("Successfully ran command") - return p_state + for cmd in self.procs: + yield cmd +# TODO: We need to test concurrency class TaskPool: def __init__(self) -> None: self.lock: threading.RLock = threading.RLock()