CLI: Restructured TaskManager and log collection

This commit is contained in:
Qubasa
2023-10-02 18:36:50 +02:00
parent c7bf416af4
commit ab6b96e516
3 changed files with 136 additions and 97 deletions

View File

@@ -1,10 +1,24 @@
import argparse import argparse
import asyncio import asyncio
from uuid import UUID
import threading
import queue
from ..dirs import get_clan_flake_toplevel from ..dirs import get_clan_flake_toplevel
from ..webui.routers import vms from ..webui.routers import vms
from ..webui.schemas import VmConfig from ..webui.schemas import VmConfig
from typing import Any, Iterator
from fastapi.responses import StreamingResponse
import pdb
def read_stream_response(stream: "StreamingResponse") -> Iterator[Any]:
    """Synchronously yield the chunks of an asynchronous streaming response.

    The response's ``body_iterator`` is an async iterator, but the CLI
    runs outside of any event loop, so each chunk is awaited on a
    private loop and handed to the caller as a plain value.

    A single event loop is reused for the whole stream rather than
    creating and destroying one per chunk, and it is closed when the
    stream is exhausted or the caller stops iterating early.

    Args:
        stream: Response whose ``body_iterator`` attribute is an async
            iterator of chunks.

    Yields:
        Each chunk produced by ``stream.body_iterator``, in order.
    """
    iterator = stream.body_iterator
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                chunk = loop.run_until_complete(iterator.__anext__())
            except StopAsyncIteration:
                # The async iterator signals exhaustion this way.
                break
            yield chunk
    finally:
        # Also runs when the consumer abandons the generator early.
        loop.close()
def create(args: argparse.Namespace) -> None: def create(args: argparse.Namespace) -> None:
clan_dir = get_clan_flake_toplevel().as_posix() clan_dir = get_clan_flake_toplevel().as_posix()
@@ -18,6 +32,13 @@ def create(args: argparse.Namespace) -> None:
res = asyncio.run(vms.create_vm(vm)) res = asyncio.run(vms.create_vm(vm))
print(res.json()) print(res.json())
uuid = UUID(res.uuid)
res = asyncio.run(vms.get_vm_logs(uuid))
for line in read_stream_response(res):
print(line)
def register_create_parser(parser: argparse.ArgumentParser) -> None: def register_create_parser(parser: argparse.ArgumentParser) -> None:

View File

@@ -1,8 +1,9 @@
import json import json
import logging import logging
import tempfile import tempfile
import time
from pathlib import Path from pathlib import Path
from typing import Annotated, Iterator from typing import Annotated, Iterator, Iterable
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Body from fastapi import APIRouter, Body
@@ -10,7 +11,7 @@ from fastapi.responses import StreamingResponse
from ...nix import nix_build, nix_eval, nix_shell from ...nix import nix_build, nix_eval, nix_shell
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task from ..task_manager import BaseTask, get_task, register_task, CmdState
from .utils import run_cmd from .utils import run_cmd
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -38,10 +39,11 @@ class BuildVmTask(BaseTask):
super().__init__(uuid) super().__init__(uuid)
self.vm = vm self.vm = vm
def get_vm_create_info(self) -> dict: def get_vm_create_info(self, cmds: Iterable[CmdState]) -> dict:
clan_dir = self.vm.flake_url clan_dir = self.vm.flake_url
machine = self.vm.flake_attr machine = self.vm.flake_attr
cmd_state = self.run_cmd( cmd = next(cmds)
cmd.run(
nix_build( nix_build(
[ [
# f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
@@ -49,41 +51,48 @@ class BuildVmTask(BaseTask):
] ]
) )
) )
vm_json = "".join(cmd_state.stdout) vm_json = "".join(cmd.stdout)
self.log.debug(f"VM JSON path: {vm_json}") self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f: with open(vm_json) as f:
return json.load(f) return json.load(f)
def task_run(self) -> None: def task_run(self) -> None:
cmds = self.register_cmds(4)
machine = self.vm.flake_attr machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}") self.log.debug(f"Creating VM for {machine}")
vm_config = self.get_vm_create_info()
# TODO: We should get this from the vm argument
vm_config = self.get_vm_create_info(cmds)
with tempfile.TemporaryDirectory() as tmpdir_: with tempfile.TemporaryDirectory() as tmpdir_:
xchg_dir = Path(tmpdir_) / "xchg" xchg_dir = Path(tmpdir_) / "xchg"
xchg_dir.mkdir() xchg_dir.mkdir()
disk_img = f"{tmpdir_}/disk.img" disk_img = f"{tmpdir_}/disk.img"
cmd = nix_shell(
cmd = next(cmds)
cmd.run(nix_shell(
["qemu"], ["qemu"],
[ [
"qemu" "qemu-img", "qemu-img",
"create", "create",
"-f", "-f",
"raw", "raw",
disk_img, disk_img,
"1024M", "1024M",
], ],
) ))
self.run_cmd(cmd)
cmd = [ cmd = next(cmds)
cmd.run([
"mkfs.ext4", "mkfs.ext4",
"-L", "-L",
"nixos", "nixos",
disk_img, disk_img,
] ])
self.run_cmd(cmd)
cmd = nix_shell( cmd = next(cmds)
cmd.run(nix_shell(
["qemu"], ["qemu"],
[ [
# fmt: off # fmt: off
@@ -106,26 +115,7 @@ class BuildVmTask(BaseTask):
"-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0', "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
# fmt: on # fmt: on
], ],
) ))
self.run_cmd(cmd)
# def run(self) -> None:
# try:
# self.log.debug(f"BuildVM with uuid {self.uuid} started")
# cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
# proc = self.run_cmd(cmd)
# self.log.debug(f"stdout: {proc.stdout}")
# vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
# self.log.debug(f"vm_path: {vm_path}")
# self.run_cmd([vm_path])
# self.finished = True
# except Exception as e:
# self.failed = True
# self.finished = True
# log.exception(e)
@router.post("/api/vms/inspect") @router.post("/api/vms/inspect")
@@ -154,21 +144,8 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]: def stream_logs() -> Iterator[str]:
task = get_task(uuid) task = get_task(uuid)
for proc in task.procs: for line in task.logs_iter():
if proc.done: yield line
log.debug("stream logs and proc is done")
for line in proc.stderr:
yield line + "\n"
for line in proc.stdout:
yield line + "\n"
continue
while True:
out = proc.output
line = out.get()
if line is None:
log.debug("stream logs and line is None")
break
yield line
return StreamingResponse( return StreamingResponse(
content=stream_logs(), content=stream_logs(),

View File

@@ -5,19 +5,72 @@ import select
import shlex import shlex
import subprocess import subprocess
import threading import threading
from typing import Any from typing import Any, Iterable, Iterator
from uuid import UUID, uuid4 from uuid import UUID, uuid4
class CmdState:
    """State and captured output of one external command.

    A ``CmdState`` is created empty, then ``run()`` spawns the command
    and blocks until it has finished. While the command runs, every
    output line is appended to ``stdout`` or ``stderr`` (without the
    trailing newline) and also pushed onto the ``_output`` queue so a
    log reader can follow along live. A ``None`` entry on the queue
    marks the end of the output.
    """

    def __init__(self, log: logging.Logger) -> None:
        """Create an idle command state that logs through *log*."""
        self.log: logging.Logger = log
        # Set by run() once the process has been spawned; None until then.
        self.p: subprocess.Popen | None = None
        self.stdout: list[str] = []
        self.stderr: list[str] = []
        # Live feed of stdout and stderr lines; None is the end marker.
        self._output: queue.SimpleQueue = queue.SimpleQueue()
        self.returncode: int | None = None
        self.done: bool = False
        self.running: bool = False
        self.cmd_str: str | None = None
        self.workdir: str | None = None

    def close_queue(self) -> None:
        """Mark the command as finished and release any waiting log reader.

        Records the process return code when a process was spawned, puts
        the ``None`` end marker on the output queue and flips the
        ``running``/``done`` flags. Each call enqueues one more end marker.
        """
        if self.p is not None:
            self.returncode = self.p.returncode
        self._output.put(None)
        self.running = False
        self.done = True

    def _pump(self, stream: Iterable[str], sink: list[str], label: str) -> None:
        """Copy every line of *stream* into *sink*, the log and the queue."""
        for raw_line in stream:
            line = raw_line.rstrip("\n")
            sink.append(line)
            self.log.debug("%s: %s", label, line)
            self._output.put(line)

    def run(self, cmd: list[str]) -> None:
        """Run *cmd* in the current working directory and capture its output.

        Both pipes are read to end-of-file on dedicated threads, so output
        written just before the process exits is not lost, neither pipe
        can stall the other, and no CPU is burnt polling. The pipes are
        closed and the process is reaped before this method returns.

        Args:
            cmd: Program and arguments; executed without a shell.

        Raises:
            RuntimeError: If the command exits with a non-zero status.
        """
        self.running = True
        try:
            self.cmd_str = shlex.join(cmd)
            self.workdir = os.getcwd()
            self.log.debug("Working directory: %s", self.workdir)
            self.log.debug("Running command: %s", self.cmd_str)
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                encoding="utf-8",
                # Undecodable bytes must not abort log collection.
                errors="replace",
                cwd=self.workdir,
            ) as proc:
                self.p = proc
                pumps = [
                    threading.Thread(
                        target=self._pump,
                        args=(proc.stdout, self.stdout, "stdout"),
                        daemon=True,
                    ),
                    threading.Thread(
                        target=self._pump,
                        args=(proc.stderr, self.stderr, "stderr"),
                        daemon=True,
                    ),
                ]
                for pump in pumps:
                    pump.start()
                # Each pump returns once its pipe reaches end-of-file.
                for pump in pumps:
                    pump.join()
            # Leaving the ``with`` block closed the pipes and waited for exit.
            if proc.returncode != 0:
                raise RuntimeError(f"Failed to run command: {self.cmd_str}")
            self.log.debug("Successfully ran command")
        finally:
            self.close_queue()
class BaseTask(threading.Thread): class BaseTask(threading.Thread):
@@ -31,64 +84,52 @@ class BaseTask(threading.Thread):
self.procs: list[CmdState] = [] self.procs: list[CmdState] = []
self.failed: bool = False self.failed: bool = False
self.finished: bool = False self.finished: bool = False
self.logs_lock = threading.Lock()
def run(self) -> None: def run(self) -> None:
try: try:
self.task_run() self.task_run()
except Exception as e: except Exception as e:
for proc in self.procs:
proc.close_queue()
self.failed = True self.failed = True
self.log.exception(e)
finally:
self.finished = True self.finished = True
self.log.exception(e)
def task_run(self) -> None: def task_run(self) -> None:
raise NotImplementedError raise NotImplementedError
def logs_iter(self) -> Iterator[str]:
    """Yield the output lines of this task's commands, in command order.

    Commands that have already completed are replayed from their
    recorded ``stderr`` and ``stdout`` lists (stderr first). A command
    that has not completed is followed live through its output queue
    until the ``None`` end marker arrives.

    Completed commands are replayed even after the task has finished,
    so a reader that connects late still receives the recorded output.
    The ``finished`` check only stops the iteration at a command that
    never completed, whose queue would otherwise block.

    TODO: readers are serialised by ``logs_lock``, and lines taken from
    a live queue are consumed. Behaviour with two clients connected to
    the same task still needs to be defined and tested.

    Yields:
        Output lines as stored by the command states.
    """
    with self.logs_lock:
        for proc in self.procs:
            if proc.done:
                yield from proc.stderr
                yield from proc.stdout
                continue
            if self.finished:
                self.log.debug("log iter: Task is finished")
                break
            # Blocks on the queue until the command enqueues its end marker.
            yield from iter(proc._output.get, None)
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
    """Create *num_cmds* command states and register them on this task.

    The states are appended to ``self.procs`` immediately, before this
    method returns, so a log reader that starts before the first
    command runs already sees them. The returned iterator hands out
    only the states created by this call, in registration order.

    Args:
        num_cmds: Number of command states to create.

    Returns:
        An iterator over the newly created states; callers take one
        with ``next()`` for each command they run.
    """
    new_cmds = [CmdState(self.log) for _ in range(num_cmds)]
    self.procs.extend(new_cmds)
    return iter(new_cmds)
# TODO: We need to test concurrency
class TaskPool: class TaskPool:
def __init__(self) -> None: def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock() self.lock: threading.RLock = threading.RLock()