class Command:
    """Run one subprocess and capture its stdout and stderr line by line.

    Every captured line is stored in ``lines`` (for replay after the command
    has ended) and pushed onto the ``_output`` queue (for live streaming).
    A ``None`` item on the queue marks the end of the output.
    """

    # Longest time, in seconds, one ``select`` call may sleep before the
    # child's exit status is polled again.
    _POLL_INTERVAL = 0.1
    # Maximum number of bytes fetched from a pipe per ``os.read`` call.
    _CHUNK_SIZE = 65536

    def __init__(self, log: logging.Logger) -> None:
        """Create an idle command that reports through ``log``."""
        self.log: logging.Logger = log
        self.p: subprocess.Popen | None = None
        self._output: queue.SimpleQueue = queue.SimpleQueue()
        self.returncode: int | None = None
        self.done: bool = False
        self.lines: list[str] = []

    def close_queue(self) -> None:
        """Mark the command as done and wake up any reader of ``_output``.

        Copies the return code from the process when one was started, puts
        the ``None`` end marker on the queue and sets ``done``.
        """
        if self.p is not None:
            self.returncode = self.p.returncode
        self._output.put(None)
        self.done = True

    def _record(self, stream: str, raw: bytes) -> None:
        """Decode one raw line, log it under its stream name and publish it."""
        # Undecodable bytes are replaced instead of aborting the whole run.
        line = raw.decode("utf-8", errors="replace")
        self.log.debug("%s: %s", stream, line.rstrip("\n"))
        self.lines.append(line)
        self._output.put(line)

    def _pump(self, proc: subprocess.Popen) -> None:
        """Forward everything ``proc`` writes to its pipes until it is done.

        Reads raw bytes with ``os.read`` only from descriptors that
        ``select`` reports as ready, and splits them on ``\\n`` so a line
        that arrives in several chunks is published once, in one piece.
        """
        assert proc.stdout is not None and proc.stderr is not None
        # fd -> (stream name, bytes of a line whose newline has not arrived)
        pending: dict[int, tuple[str, bytes]] = {
            proc.stderr.fileno(): ("stderr", b""),
            proc.stdout.fileno(): ("stdout", b""),
        }
        while pending:
            exited = proc.poll() is not None
            # While the child runs, sleep in select instead of spinning.
            # After it has exited, only collect what is already buffered.
            timeout = 0 if exited else self._POLL_INTERVAL
            ready, _, _ = select.select(list(pending), [], [], timeout)
            if exited and not ready:
                # The child is gone and neither pipe has data or EOF pending
                # (something else still holds the write end): stop waiting.
                break
            for fd in ready:
                name, partial = pending[fd]
                chunk = os.read(fd, self._CHUNK_SIZE)
                *complete, partial = (partial + chunk).split(b"\n")
                for raw in complete:
                    self._record(name, raw + b"\n")
                if chunk:
                    pending[fd] = (name, partial)
                    continue
                # Empty read means EOF: flush an unterminated last line.
                if partial:
                    self._record(name, partial)
                del pending[fd]
        # Streams that never reached EOF may still hold an unterminated line.
        for name, partial in pending.values():
            if partial:
                self._record(name, partial)

    def run(self, cmd: list[str]) -> None:
        """Run ``cmd`` to completion while capturing its output.

        Args:
            cmd: Program and arguments, passed to ``subprocess.Popen``
                without a shell.

        Raises:
            ClanError: If the command exits with a non-zero return code.

        The queue is always closed (see ``close_queue``), whether the command
        succeeded, failed or could not be started at all.
        """
        try:
            self.log.debug(f"Running command: {shlex.join(cmd)}")
            # The context manager closes both pipes and reaps the child.
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            ) as proc:
                self.p = proc
                self._pump(proc)

            if proc.returncode != 0:
                raise ClanError(f"Failed to run command: {shlex.join(cmd)}")

            self.log.debug("Successfully ran command")
        finally:
            self.close_queue()
def log_lines(self) -> Iterator[str]:
    """Yield the output lines of every registered command, in order.

    For a command that is already done the lines are replayed from its
    ``lines`` list; otherwise they are streamed from its ``_output`` queue
    until the ``None`` end marker arrives.

    ``logs_lock`` is held for as long as the generator is being consumed.

    Yields:
        Output lines exactly as the commands recorded them.
    """
    # Seconds to wait on a silent queue before re-checking the task status.
    poll_interval = 0.1
    with self.logs_lock:
        for proc in self.procs:
            if proc.done:
                # Replay even when the task has already finished: the
                # captured output is complete and reading it cannot block.
                yield from proc.lines
                continue

            while True:
                try:
                    line = proc._output.get(timeout=poll_interval)
                except queue.Empty:
                    task_over = self.status in (
                        TaskStatus.FINISHED,
                        TaskStatus.FAILED,
                    )
                    # A command that was registered but never run gets no
                    # end marker; once the task is over and the queue is
                    # drained there is nothing left to wait for.
                    if task_over and proc._output.empty():
                        break
                    continue
                # Only the explicit end marker terminates the stream.
                if line is None:
                    break
                yield line
@router.get("/api/vms/{uuid}/status")
async def get_vm_status(uuid: UUID) -> VmStatusResponse:
    """Report the status of the task identified by ``uuid``.

    Returns:
        A ``VmStatusResponse`` carrying the task's ``status`` and, when the
        task recorded an exception, a textual description of it in
        ``error``. ``error`` is ``None`` when no exception was recorded.
    """
    task = get_task(uuid)
    log.debug(msg=f"error: {task.error}, task.status: {task.status}")
    # str(None) would produce the literal text "None"; keep a real null so
    # clients can tell "no error" apart from an error message.
    # Exceptions created without a message stringify to "", so fall back to
    # their repr to keep the failure identifiable.
    error = None if task.error is None else (str(task.error) or repr(task.error))
    return VmStatusResponse(status=task.status, error=error)
VmCreateResponse(BaseModel): diff --git a/pkgs/clan-cli/tests/test_vms_api.py b/pkgs/clan-cli/tests/test_vms_api.py index 7f09dedad..32b74576c 100644 --- a/pkgs/clan-cli/tests/test_vms_api.py +++ b/pkgs/clan-cli/tests/test_vms_api.py @@ -58,14 +58,13 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None: print("=========VM LOGS==========") assert isinstance(response.stream, SyncByteStream) for line in response.stream: - assert line != b"", "Failed to get vm logs" print(line.decode("utf-8")) print("=========END LOGS==========") assert response.status_code == 200, "Failed to get vm logs" response = api.get(f"/api/vms/{uuid}/status") assert response.status_code == 200, "Failed to get vm status" - returncodes = response.json()["returncode"] - assert response.json()["running"] is False, "VM is still running. Should be stopped" - for exit_code in returncodes: - assert exit_code == 0, "One VM failed with exit code != 0" + data = response.json() + assert ( + data["status"] == "FINISHED" + ), f"Expected to be finished, but got {data['status']} ({data})"