From fe1a3f05419bba9f93eb1377ebaa439bde768043 Mon Sep 17 00:00:00 2001 From: lassulus Date: Wed, 4 Oct 2023 17:41:20 +0200 Subject: [PATCH] task_manager: fix race conditions --- pkgs/clan-cli/clan_cli/task_manager.py | 83 +++++++++++---------- pkgs/clan-cli/clan_cli/vms/create.py | 6 +- pkgs/clan-cli/clan_cli/webui/routers/vms.py | 3 +- 3 files changed, 48 insertions(+), 44 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/task_manager.py b/pkgs/clan-cli/clan_cli/task_manager.py index e5392d1db..d1fa34045 100644 --- a/pkgs/clan-cli/clan_cli/task_manager.py +++ b/pkgs/clan-cli/clan_cli/task_manager.py @@ -21,7 +21,8 @@ class Command: self._output: queue.SimpleQueue = queue.SimpleQueue() self.returncode: int | None = None self.done: bool = False - self.lines: list[str] = [] + self.stdout: list[str] = [] + self.stderr: list[str] = [] def close_queue(self) -> None: if self.p is not None: @@ -31,36 +32,36 @@ class Command: def run(self, cmd: list[str]) -> None: self.running = True - try: - self.log.debug(f"Running command: {shlex.join(cmd)}") - self.p = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding="utf-8", - ) - assert self.p.stdout is not None and self.p.stderr is not None - os.set_blocking(self.p.stdout.fileno(), False) - os.set_blocking(self.p.stderr.fileno(), False) + self.log.debug(f"Running command: {shlex.join(cmd)}") + self.p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + assert self.p.stdout is not None and self.p.stderr is not None + os.set_blocking(self.p.stdout.fileno(), False) + os.set_blocking(self.p.stderr.fileno(), False) - while self.p.poll() is None: - # Check if stderr is ready to be read from - rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0) - for fd in rlist: - try: - for line in fd: - self.log.debug("stdout: %s", line) - self.lines.append(line) - self._output.put(line) - except BlockingIOError: - continue + while self.p.poll() is None: + # Check if stderr is ready to be read from + rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0) + for fd in rlist: + try: + for line in fd: + self.log.debug("stdout: %s", line) + if fd == self.p.stderr: + self.stderr.append(line) + else: + self.stdout.append(line) + self._output.put(line) + except BlockingIOError: + continue - if self.p.returncode != 0: - raise ClanError(f"Failed to run command: {shlex.join(cmd)}") + if self.p.returncode != 0: + raise ClanError(f"Failed to run command: {shlex.join(cmd)}") - self.log.debug("Successfully ran command") - finally: - self.close_queue() + self.log.debug("Successfully ran command") class TaskStatus(str, Enum): @@ -71,7 +72,7 @@ class TaskStatus(str, Enum): class BaseTask: - def __init__(self, uuid: UUID) -> None: + def __init__(self, uuid: UUID, num_cmds: int) -> None: # constructor self.uuid: UUID = uuid self.log = logging.getLogger(__name__) @@ -80,6 +81,10 @@ class BaseTask: self.logs_lock = threading.Lock() self.error: Exception | None = None + for _ in range(num_cmds): + cmd = Command(self.log) + self.procs.append(cmd) + def _run(self) -> None: self.status = TaskStatus.RUNNING try: @@ -87,13 +92,14 @@ class BaseTask: except Exception as e: # FIXME: fix exception handling here traceback.print_exception(*sys.exc_info()) - for proc in self.procs: - proc.close_queue() self.error = e self.log.exception(e) self.status = TaskStatus.FAILED else: self.status = TaskStatus.FINISHED + finally: + for proc in self.procs: + proc.close_queue() def run(self) -> None: raise NotImplementedError @@ -106,19 +112,16 @@ class BaseTask: return # process has finished if proc.done: - for line in proc.lines: + for line in proc.stdout: + yield line + for line in proc.stderr: yield line else: while line := proc._output.get(): yield line - def register_commands(self, num_cmds: int) -> Iterator[Command]: - for _ in range(num_cmds): - cmd = Command(self.log) - self.procs.append(cmd) - - for cmd in self.procs: - yield cmd + def commands(self) -> Iterator[Command]: + yield from self.procs # TODO: We need to test concurrency @@ -157,6 +160,6 @@ def create_task(task_type: Type[T], *args: Any) -> T: uuid = uuid4() task = task_type(uuid, *args) - threading.Thread(target=task._run).start() POOL[uuid] = task + threading.Thread(target=task._run).start() return task diff --git a/pkgs/clan-cli/clan_cli/vms/create.py b/pkgs/clan-cli/clan_cli/vms/create.py index 8235ece64..e754703e7 100644 --- a/pkgs/clan-cli/clan_cli/vms/create.py +++ b/pkgs/clan-cli/clan_cli/vms/create.py @@ -15,7 +15,7 @@ from .inspect import VmConfig, inspect_vm class BuildVmTask(BaseTask): def __init__(self, uuid: UUID, vm: VmConfig) -> None: - super().__init__(uuid) + super().__init__(uuid, num_cmds=4) self.vm = vm def get_vm_create_info(self, cmds: Iterator[Command]) -> dict: @@ -30,13 +30,13 @@ class BuildVmTask(BaseTask): ] ) ) - vm_json = "".join(cmd.lines) + vm_json = "".join(cmd.stdout) self.log.debug(f"VM JSON path: {vm_json}") with open(vm_json.strip()) as f: return json.load(f) def run(self) -> None: - cmds = self.register_commands(4) + cmds = self.commands() machine = self.vm.flake_attr self.log.debug(f"Creating VM for {machine}") diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index b76e5dbac..5ed46ecd3 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -28,7 +28,8 @@ async def inspect_vm( async def get_vm_status(uuid: UUID) -> VmStatusResponse: task = get_task(uuid) log.debug(msg=f"error: {task.error}, task.status: {task.status}") - return VmStatusResponse(status=task.status, error=str(task.error)) + error = str(task.error) if task.error is not None else None + return VmStatusResponse(status=task.status, error=error) @router.get("/api/vms/{uuid}/logs")