improve task manager to report exceptions better
This commit is contained in:
@@ -4,125 +4,117 @@ import queue
|
|||||||
import select
|
import select
|
||||||
import shlex
|
import shlex
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
import traceback
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Iterator, Type, TypeVar
|
from typing import Any, Iterator, Type, TypeVar
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from .errors import ClanError
|
||||||
|
|
||||||
class CmdState:
|
|
||||||
|
class Command:
|
||||||
def __init__(self, log: logging.Logger) -> None:
|
def __init__(self, log: logging.Logger) -> None:
|
||||||
self.log: logging.Logger = log
|
self.log: logging.Logger = log
|
||||||
self.p: subprocess.Popen | None = None
|
self.p: subprocess.Popen | None = None
|
||||||
self.stdout: list[str] = []
|
|
||||||
self.stderr: list[str] = []
|
|
||||||
self._output: queue.SimpleQueue = queue.SimpleQueue()
|
self._output: queue.SimpleQueue = queue.SimpleQueue()
|
||||||
self.returncode: int | None = None
|
self.returncode: int | None = None
|
||||||
self.done: bool = False
|
self.done: bool = False
|
||||||
self.running: bool = False
|
self.lines: list[str] = []
|
||||||
self.cmd_str: str | None = None
|
|
||||||
self.workdir: str | None = None
|
|
||||||
|
|
||||||
def close_queue(self) -> None:
|
def close_queue(self) -> None:
|
||||||
if self.p is not None:
|
if self.p is not None:
|
||||||
self.returncode = self.p.returncode
|
self.returncode = self.p.returncode
|
||||||
self._output.put(None)
|
self._output.put(None)
|
||||||
self.running = False
|
|
||||||
self.done = True
|
self.done = True
|
||||||
|
|
||||||
def run(self, cmd: list[str]) -> None:
|
def run(self, cmd: list[str]) -> None:
|
||||||
self.running = True
|
self.running = True
|
||||||
try:
|
try:
|
||||||
self.cmd_str = shlex.join(cmd)
|
|
||||||
self.workdir = os.getcwd()
|
|
||||||
self.log.debug(f"Working directory: {self.workdir}")
|
|
||||||
self.log.debug(f"Running command: {shlex.join(cmd)}")
|
self.log.debug(f"Running command: {shlex.join(cmd)}")
|
||||||
self.p = subprocess.Popen(
|
self.p = subprocess.Popen(
|
||||||
cmd,
|
cmd,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
cwd=self.workdir,
|
|
||||||
)
|
)
|
||||||
|
assert self.p.stdout is not None and self.p.stderr is not None
|
||||||
|
os.set_blocking(self.p.stdout.fileno(), False)
|
||||||
|
os.set_blocking(self.p.stderr.fileno(), False)
|
||||||
|
|
||||||
while self.p.poll() is None:
|
while self.p.poll() is None:
|
||||||
# Check if stderr is ready to be read from
|
# Check if stderr is ready to be read from
|
||||||
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
|
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
|
||||||
if self.p.stderr in rlist:
|
for fd in rlist:
|
||||||
assert self.p.stderr is not None
|
try:
|
||||||
line = self.p.stderr.readline()
|
for line in fd:
|
||||||
if line != "":
|
self.log.debug("stdout: %s", line)
|
||||||
line = line.strip("\n")
|
self.lines.append(line)
|
||||||
self.stderr.append(line)
|
self._output.put(line)
|
||||||
self.log.debug("stderr: %s", line)
|
except BlockingIOError:
|
||||||
self._output.put(line + "\n")
|
continue
|
||||||
|
|
||||||
if self.p.stdout in rlist:
|
|
||||||
assert self.p.stdout is not None
|
|
||||||
line = self.p.stdout.readline()
|
|
||||||
if line != "":
|
|
||||||
line = line.strip("\n")
|
|
||||||
self.stdout.append(line)
|
|
||||||
self.log.debug("stdout: %s", line)
|
|
||||||
self._output.put(line + "\n")
|
|
||||||
|
|
||||||
if self.p.returncode != 0:
|
if self.p.returncode != 0:
|
||||||
raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
|
raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
|
||||||
|
|
||||||
self.log.debug("Successfully ran command")
|
self.log.debug("Successfully ran command")
|
||||||
finally:
|
finally:
|
||||||
self.close_queue()
|
self.close_queue()
|
||||||
|
|
||||||
|
|
||||||
class BaseTask(threading.Thread):
|
class TaskStatus(str, Enum):
|
||||||
def __init__(self, uuid: UUID) -> None:
|
NOTSTARTED = "NOTSTARTED"
|
||||||
# calling parent class constructor
|
RUNNING = "RUNNING"
|
||||||
threading.Thread.__init__(self)
|
FINISHED = "FINISHED"
|
||||||
|
FAILED = "FAILED"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTask:
|
||||||
|
def __init__(self, uuid: UUID) -> None:
|
||||||
# constructor
|
# constructor
|
||||||
self.uuid: UUID = uuid
|
self.uuid: UUID = uuid
|
||||||
self.log = logging.getLogger(__name__)
|
self.log = logging.getLogger(__name__)
|
||||||
self.procs: list[CmdState] = []
|
self.procs: list[Command] = []
|
||||||
self.failed: bool = False
|
self.status = TaskStatus.NOTSTARTED
|
||||||
self.finished: bool = False
|
|
||||||
self.logs_lock = threading.Lock()
|
self.logs_lock = threading.Lock()
|
||||||
|
self.error: Exception | None = None
|
||||||
|
|
||||||
def run(self) -> None:
|
def _run(self) -> None:
|
||||||
|
self.status = TaskStatus.RUNNING
|
||||||
try:
|
try:
|
||||||
self.task_run()
|
self.run()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# FIXME: fix exception handling here
|
||||||
|
traceback.print_exception(*sys.exc_info())
|
||||||
for proc in self.procs:
|
for proc in self.procs:
|
||||||
proc.close_queue()
|
proc.close_queue()
|
||||||
self.failed = True
|
self.error = e
|
||||||
self.log.exception(e)
|
self.log.exception(e)
|
||||||
finally:
|
self.status = TaskStatus.FAILED
|
||||||
self.finished = True
|
else:
|
||||||
|
self.status = TaskStatus.FINISHED
|
||||||
|
|
||||||
def task_run(self) -> None:
|
def run(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
## TODO: If two clients are connected to the same task,
|
## TODO: If two clients are connected to the same task,
|
||||||
def logs_iter(self) -> Iterator[str]:
|
def log_lines(self) -> Iterator[str]:
|
||||||
with self.logs_lock:
|
with self.logs_lock:
|
||||||
for proc in self.procs:
|
for proc in self.procs:
|
||||||
if self.finished:
|
if self.status == TaskStatus.FINISHED:
|
||||||
self.log.debug("log iter: Task is finished")
|
return
|
||||||
break
|
# process has finished
|
||||||
if proc.done:
|
if proc.done:
|
||||||
for line in proc.stderr:
|
for line in proc.lines:
|
||||||
yield line + "\n"
|
yield line
|
||||||
for line in proc.stdout:
|
else:
|
||||||
yield line + "\n"
|
while line := proc._output.get():
|
||||||
continue
|
yield line
|
||||||
while True:
|
|
||||||
out = proc._output
|
|
||||||
line = out.get()
|
|
||||||
if line is None:
|
|
||||||
break
|
|
||||||
yield line
|
|
||||||
|
|
||||||
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
|
def register_commands(self, num_cmds: int) -> Iterator[Command]:
|
||||||
for i in range(num_cmds):
|
for _ in range(num_cmds):
|
||||||
cmd = CmdState(self.log)
|
cmd = Command(self.log)
|
||||||
self.procs.append(cmd)
|
self.procs.append(cmd)
|
||||||
|
|
||||||
for cmd in self.procs:
|
for cmd in self.procs:
|
||||||
@@ -165,6 +157,6 @@ def create_task(task_type: Type[T], *args: Any) -> T:
|
|||||||
uuid = uuid4()
|
uuid = uuid4()
|
||||||
|
|
||||||
task = task_type(uuid, *args)
|
task = task_type(uuid, *args)
|
||||||
|
threading.Thread(target=task._run).start()
|
||||||
POOL[uuid] = task
|
POOL[uuid] = task
|
||||||
task.start()
|
|
||||||
return task
|
return task
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from ..dirs import get_clan_flake_toplevel
|
from ..dirs import get_clan_flake_toplevel
|
||||||
from ..nix import nix_build, nix_shell
|
from ..nix import nix_build, nix_shell
|
||||||
from ..task_manager import BaseTask, CmdState, create_task
|
from ..task_manager import BaseTask, Command, create_task
|
||||||
from .inspect import VmConfig, inspect_vm
|
from .inspect import VmConfig, inspect_vm
|
||||||
|
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ class BuildVmTask(BaseTask):
|
|||||||
super().__init__(uuid)
|
super().__init__(uuid)
|
||||||
self.vm = vm
|
self.vm = vm
|
||||||
|
|
||||||
def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
|
def get_vm_create_info(self, cmds: Iterator[Command]) -> dict:
|
||||||
clan_dir = self.vm.flake_url
|
clan_dir = self.vm.flake_url
|
||||||
machine = self.vm.flake_attr
|
machine = self.vm.flake_attr
|
||||||
cmd = next(cmds)
|
cmd = next(cmds)
|
||||||
@@ -30,13 +30,13 @@ class BuildVmTask(BaseTask):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
vm_json = "".join(cmd.stdout)
|
vm_json = "".join(cmd.lines)
|
||||||
self.log.debug(f"VM JSON path: {vm_json}")
|
self.log.debug(f"VM JSON path: {vm_json}")
|
||||||
with open(vm_json) as f:
|
with open(vm_json.strip()) as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def task_run(self) -> None:
|
def run(self) -> None:
|
||||||
cmds = self.register_cmds(4)
|
cmds = self.register_commands(4)
|
||||||
|
|
||||||
machine = self.vm.flake_attr
|
machine = self.vm.flake_attr
|
||||||
self.log.debug(f"Creating VM for {machine}")
|
self.log.debug(f"Creating VM for {machine}")
|
||||||
@@ -121,7 +121,7 @@ def create_command(args: argparse.Namespace) -> None:
|
|||||||
vm = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
|
vm = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
|
||||||
|
|
||||||
task = create_vm(vm)
|
task = create_vm(vm)
|
||||||
for line in task.logs_iter():
|
for line in task.log_lines():
|
||||||
print(line, end="")
|
print(line, end="")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,9 +27,8 @@ async def inspect_vm(
|
|||||||
@router.get("/api/vms/{uuid}/status")
|
@router.get("/api/vms/{uuid}/status")
|
||||||
async def get_vm_status(uuid: UUID) -> VmStatusResponse:
|
async def get_vm_status(uuid: UUID) -> VmStatusResponse:
|
||||||
task = get_task(uuid)
|
task = get_task(uuid)
|
||||||
status: list[int | None] = list(map(lambda x: x.returncode, task.procs))
|
log.debug(msg=f"error: {task.error}, task.status: {task.status}")
|
||||||
log.debug(msg=f"returncodes: {status}. task.finished: {task.finished}")
|
return VmStatusResponse(status=task.status, error=str(task.error))
|
||||||
return VmStatusResponse(running=not task.finished, returncode=status)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/vms/{uuid}/logs")
|
@router.get("/api/vms/{uuid}/logs")
|
||||||
@@ -38,7 +37,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
|
|||||||
def stream_logs() -> Iterator[str]:
|
def stream_logs() -> Iterator[str]:
|
||||||
task = get_task(uuid)
|
task = get_task(uuid)
|
||||||
|
|
||||||
yield from task.logs_iter()
|
yield from task.log_lines()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
content=stream_logs(),
|
content=stream_logs(),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import List
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..task_manager import TaskStatus
|
||||||
from ..vms.inspect import VmConfig
|
from ..vms.inspect import VmConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -38,8 +39,8 @@ class SchemaResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class VmStatusResponse(BaseModel):
|
class VmStatusResponse(BaseModel):
|
||||||
returncode: list[int | None]
|
error: str | None
|
||||||
running: bool
|
status: TaskStatus
|
||||||
|
|
||||||
|
|
||||||
class VmCreateResponse(BaseModel):
|
class VmCreateResponse(BaseModel):
|
||||||
|
|||||||
@@ -58,14 +58,13 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
|
|||||||
print("=========VM LOGS==========")
|
print("=========VM LOGS==========")
|
||||||
assert isinstance(response.stream, SyncByteStream)
|
assert isinstance(response.stream, SyncByteStream)
|
||||||
for line in response.stream:
|
for line in response.stream:
|
||||||
assert line != b"", "Failed to get vm logs"
|
|
||||||
print(line.decode("utf-8"))
|
print(line.decode("utf-8"))
|
||||||
print("=========END LOGS==========")
|
print("=========END LOGS==========")
|
||||||
assert response.status_code == 200, "Failed to get vm logs"
|
assert response.status_code == 200, "Failed to get vm logs"
|
||||||
|
|
||||||
response = api.get(f"/api/vms/{uuid}/status")
|
response = api.get(f"/api/vms/{uuid}/status")
|
||||||
assert response.status_code == 200, "Failed to get vm status"
|
assert response.status_code == 200, "Failed to get vm status"
|
||||||
returncodes = response.json()["returncode"]
|
data = response.json()
|
||||||
assert response.json()["running"] is False, "VM is still running. Should be stopped"
|
assert (
|
||||||
for exit_code in returncodes:
|
data["status"] == "FINISHED"
|
||||||
assert exit_code == 0, "One VM failed with exit code != 0"
|
), f"Expected to be finished, but got {data['status']} ({data})"
|
||||||
|
|||||||
Reference in New Issue
Block a user