improve task manager to report exceptions better

This commit is contained in:
Jörg Thalheim
2023-10-04 16:44:26 +02:00
parent ff1fb784e7
commit 04ba80f614
5 changed files with 70 additions and 79 deletions

View File

@@ -4,125 +4,117 @@ import queue
import select import select
import shlex import shlex
import subprocess import subprocess
import sys
import threading import threading
import traceback
from enum import Enum
from typing import Any, Iterator, Type, TypeVar from typing import Any, Iterator, Type, TypeVar
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from .errors import ClanError
class CmdState:
class Command:
def __init__(self, log: logging.Logger) -> None: def __init__(self, log: logging.Logger) -> None:
self.log: logging.Logger = log self.log: logging.Logger = log
self.p: subprocess.Popen | None = None self.p: subprocess.Popen | None = None
self.stdout: list[str] = []
self.stderr: list[str] = []
self._output: queue.SimpleQueue = queue.SimpleQueue() self._output: queue.SimpleQueue = queue.SimpleQueue()
self.returncode: int | None = None self.returncode: int | None = None
self.done: bool = False self.done: bool = False
self.running: bool = False self.lines: list[str] = []
self.cmd_str: str | None = None
self.workdir: str | None = None
def close_queue(self) -> None: def close_queue(self) -> None:
if self.p is not None: if self.p is not None:
self.returncode = self.p.returncode self.returncode = self.p.returncode
self._output.put(None) self._output.put(None)
self.running = False
self.done = True self.done = True
def run(self, cmd: list[str]) -> None: def run(self, cmd: list[str]) -> None:
self.running = True self.running = True
try: try:
self.cmd_str = shlex.join(cmd)
self.workdir = os.getcwd()
self.log.debug(f"Working directory: {self.workdir}")
self.log.debug(f"Running command: {shlex.join(cmd)}") self.log.debug(f"Running command: {shlex.join(cmd)}")
self.p = subprocess.Popen( self.p = subprocess.Popen(
cmd, cmd,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
encoding="utf-8", encoding="utf-8",
cwd=self.workdir,
) )
assert self.p.stdout is not None and self.p.stderr is not None
os.set_blocking(self.p.stdout.fileno(), False)
os.set_blocking(self.p.stderr.fileno(), False)
while self.p.poll() is None: while self.p.poll() is None:
# Check if stderr is ready to be read from # Check if stderr is ready to be read from
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0) rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
if self.p.stderr in rlist: for fd in rlist:
assert self.p.stderr is not None try:
line = self.p.stderr.readline() for line in fd:
if line != "": self.log.debug("stdout: %s", line)
line = line.strip("\n") self.lines.append(line)
self.stderr.append(line) self._output.put(line)
self.log.debug("stderr: %s", line) except BlockingIOError:
self._output.put(line + "\n") continue
if self.p.stdout in rlist:
assert self.p.stdout is not None
line = self.p.stdout.readline()
if line != "":
line = line.strip("\n")
self.stdout.append(line)
self.log.debug("stdout: %s", line)
self._output.put(line + "\n")
if self.p.returncode != 0: if self.p.returncode != 0:
raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}") raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
self.log.debug("Successfully ran command") self.log.debug("Successfully ran command")
finally: finally:
self.close_queue() self.close_queue()
class BaseTask(threading.Thread): class TaskStatus(str, Enum):
def __init__(self, uuid: UUID) -> None: NOTSTARTED = "NOTSTARTED"
# calling parent class constructor RUNNING = "RUNNING"
threading.Thread.__init__(self) FINISHED = "FINISHED"
FAILED = "FAILED"
class BaseTask:
def __init__(self, uuid: UUID) -> None:
# constructor # constructor
self.uuid: UUID = uuid self.uuid: UUID = uuid
self.log = logging.getLogger(__name__) self.log = logging.getLogger(__name__)
self.procs: list[CmdState] = [] self.procs: list[Command] = []
self.failed: bool = False self.status = TaskStatus.NOTSTARTED
self.finished: bool = False
self.logs_lock = threading.Lock() self.logs_lock = threading.Lock()
self.error: Exception | None = None
def run(self) -> None: def _run(self) -> None:
self.status = TaskStatus.RUNNING
try: try:
self.task_run() self.run()
except Exception as e: except Exception as e:
# FIXME: fix exception handling here
traceback.print_exception(*sys.exc_info())
for proc in self.procs: for proc in self.procs:
proc.close_queue() proc.close_queue()
self.failed = True self.error = e
self.log.exception(e) self.log.exception(e)
finally: self.status = TaskStatus.FAILED
self.finished = True else:
self.status = TaskStatus.FINISHED
def task_run(self) -> None: def run(self) -> None:
raise NotImplementedError raise NotImplementedError
## TODO: If two clients are connected to the same task, ## TODO: If two clients are connected to the same task,
def logs_iter(self) -> Iterator[str]: def log_lines(self) -> Iterator[str]:
with self.logs_lock: with self.logs_lock:
for proc in self.procs: for proc in self.procs:
if self.finished: if self.status == TaskStatus.FINISHED:
self.log.debug("log iter: Task is finished") return
break # process has finished
if proc.done: if proc.done:
for line in proc.stderr: for line in proc.lines:
yield line + "\n" yield line
for line in proc.stdout: else:
yield line + "\n" while line := proc._output.get():
continue yield line
while True:
out = proc._output
line = out.get()
if line is None:
break
yield line
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]: def register_commands(self, num_cmds: int) -> Iterator[Command]:
for i in range(num_cmds): for _ in range(num_cmds):
cmd = CmdState(self.log) cmd = Command(self.log)
self.procs.append(cmd) self.procs.append(cmd)
for cmd in self.procs: for cmd in self.procs:
@@ -165,6 +157,6 @@ def create_task(task_type: Type[T], *args: Any) -> T:
uuid = uuid4() uuid = uuid4()
task = task_type(uuid, *args) task = task_type(uuid, *args)
threading.Thread(target=task._run).start()
POOL[uuid] = task POOL[uuid] = task
task.start()
return task return task

View File

@@ -9,7 +9,7 @@ from uuid import UUID
from ..dirs import get_clan_flake_toplevel from ..dirs import get_clan_flake_toplevel
from ..nix import nix_build, nix_shell from ..nix import nix_build, nix_shell
from ..task_manager import BaseTask, CmdState, create_task from ..task_manager import BaseTask, Command, create_task
from .inspect import VmConfig, inspect_vm from .inspect import VmConfig, inspect_vm
@@ -18,7 +18,7 @@ class BuildVmTask(BaseTask):
super().__init__(uuid) super().__init__(uuid)
self.vm = vm self.vm = vm
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict: def get_vm_create_info(self, cmds: Iterator[Command]) -> dict:
clan_dir = self.vm.flake_url clan_dir = self.vm.flake_url
machine = self.vm.flake_attr machine = self.vm.flake_attr
cmd = next(cmds) cmd = next(cmds)
@@ -30,13 +30,13 @@ class BuildVmTask(BaseTask):
] ]
) )
) )
vm_json = "".join(cmd.stdout) vm_json = "".join(cmd.lines)
self.log.debug(f"VM JSON path: {vm_json}") self.log.debug(f"VM JSON path: {vm_json}")
with open(vm_json) as f: with open(vm_json.strip()) as f:
return json.load(f) return json.load(f)
def task_run(self) -> None: def run(self) -> None:
cmds = self.register_cmds(4) cmds = self.register_commands(4)
machine = self.vm.flake_attr machine = self.vm.flake_attr
self.log.debug(f"Creating VM for {machine}") self.log.debug(f"Creating VM for {machine}")
@@ -121,7 +121,7 @@ def create_command(args: argparse.Namespace) -> None:
vm = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine)) vm = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
task = create_vm(vm) task = create_vm(vm)
for line in task.logs_iter(): for line in task.log_lines():
print(line, end="") print(line, end="")

View File

@@ -27,9 +27,8 @@ async def inspect_vm(
@router.get("/api/vms/{uuid}/status") @router.get("/api/vms/{uuid}/status")
async def get_vm_status(uuid: UUID) -> VmStatusResponse: async def get_vm_status(uuid: UUID) -> VmStatusResponse:
task = get_task(uuid) task = get_task(uuid)
status: list[int | None] = list(map(lambda x: x.returncode, task.procs)) log.debug(msg=f"error: {task.error}, task.status: {task.status}")
log.debug(msg=f"returncodes: {status}. task.finished: {task.finished}") return VmStatusResponse(status=task.status, error=str(task.error))
return VmStatusResponse(running=not task.finished, returncode=status)
@router.get("/api/vms/{uuid}/logs") @router.get("/api/vms/{uuid}/logs")
@@ -38,7 +37,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
def stream_logs() -> Iterator[str]: def stream_logs() -> Iterator[str]:
task = get_task(uuid) task = get_task(uuid)
yield from task.logs_iter() yield from task.log_lines()
return StreamingResponse( return StreamingResponse(
content=stream_logs(), content=stream_logs(),

View File

@@ -3,6 +3,7 @@ from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..task_manager import TaskStatus
from ..vms.inspect import VmConfig from ..vms.inspect import VmConfig
@@ -38,8 +39,8 @@ class SchemaResponse(BaseModel):
class VmStatusResponse(BaseModel): class VmStatusResponse(BaseModel):
returncode: list[int | None] error: str | None
running: bool status: TaskStatus
class VmCreateResponse(BaseModel): class VmCreateResponse(BaseModel):

View File

@@ -58,14 +58,13 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
print("=========VM LOGS==========") print("=========VM LOGS==========")
assert isinstance(response.stream, SyncByteStream) assert isinstance(response.stream, SyncByteStream)
for line in response.stream: for line in response.stream:
assert line != b"", "Failed to get vm logs"
print(line.decode("utf-8")) print(line.decode("utf-8"))
print("=========END LOGS==========") print("=========END LOGS==========")
assert response.status_code == 200, "Failed to get vm logs" assert response.status_code == 200, "Failed to get vm logs"
response = api.get(f"/api/vms/{uuid}/status") response = api.get(f"/api/vms/{uuid}/status")
assert response.status_code == 200, "Failed to get vm status" assert response.status_code == 200, "Failed to get vm status"
returncodes = response.json()["returncode"] data = response.json()
assert response.json()["running"] is False, "VM is still running. Should be stopped" assert (
for exit_code in returncodes: data["status"] == "FINISHED"
assert exit_code == 0, "One VM failed with exit code != 0" ), f"Expected to be finished, but got {data['status']} ({data})"