improve task manager to report exceptions better
This commit is contained in:
@@ -4,125 +4,117 @@ import queue
|
||||
import select
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Iterator, Type, TypeVar
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from .errors import ClanError
|
||||
|
||||
class CmdState:
|
||||
|
||||
class Command:
|
||||
def __init__(self, log: logging.Logger) -> None:
|
||||
self.log: logging.Logger = log
|
||||
self.p: subprocess.Popen | None = None
|
||||
self.stdout: list[str] = []
|
||||
self.stderr: list[str] = []
|
||||
self._output: queue.SimpleQueue = queue.SimpleQueue()
|
||||
self.returncode: int | None = None
|
||||
self.done: bool = False
|
||||
self.running: bool = False
|
||||
self.cmd_str: str | None = None
|
||||
self.workdir: str | None = None
|
||||
self.lines: list[str] = []
|
||||
|
||||
def close_queue(self) -> None:
|
||||
if self.p is not None:
|
||||
self.returncode = self.p.returncode
|
||||
self._output.put(None)
|
||||
self.running = False
|
||||
self.done = True
|
||||
|
||||
def run(self, cmd: list[str]) -> None:
|
||||
self.running = True
|
||||
try:
|
||||
self.cmd_str = shlex.join(cmd)
|
||||
self.workdir = os.getcwd()
|
||||
self.log.debug(f"Working directory: {self.workdir}")
|
||||
self.log.debug(f"Running command: {shlex.join(cmd)}")
|
||||
self.p = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf-8",
|
||||
cwd=self.workdir,
|
||||
)
|
||||
assert self.p.stdout is not None and self.p.stderr is not None
|
||||
os.set_blocking(self.p.stdout.fileno(), False)
|
||||
os.set_blocking(self.p.stderr.fileno(), False)
|
||||
|
||||
while self.p.poll() is None:
|
||||
# Check if stderr is ready to be read from
|
||||
rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
|
||||
if self.p.stderr in rlist:
|
||||
assert self.p.stderr is not None
|
||||
line = self.p.stderr.readline()
|
||||
if line != "":
|
||||
line = line.strip("\n")
|
||||
self.stderr.append(line)
|
||||
self.log.debug("stderr: %s", line)
|
||||
self._output.put(line + "\n")
|
||||
|
||||
if self.p.stdout in rlist:
|
||||
assert self.p.stdout is not None
|
||||
line = self.p.stdout.readline()
|
||||
if line != "":
|
||||
line = line.strip("\n")
|
||||
self.stdout.append(line)
|
||||
self.log.debug("stdout: %s", line)
|
||||
self._output.put(line + "\n")
|
||||
for fd in rlist:
|
||||
try:
|
||||
for line in fd:
|
||||
self.log.debug("stdout: %s", line)
|
||||
self.lines.append(line)
|
||||
self._output.put(line)
|
||||
except BlockingIOError:
|
||||
continue
|
||||
|
||||
if self.p.returncode != 0:
|
||||
raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
|
||||
raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
|
||||
|
||||
self.log.debug("Successfully ran command")
|
||||
finally:
|
||||
self.close_queue()
|
||||
|
||||
|
||||
class BaseTask(threading.Thread):
|
||||
def __init__(self, uuid: UUID) -> None:
|
||||
# calling parent class constructor
|
||||
threading.Thread.__init__(self)
|
||||
class TaskStatus(str, Enum):
|
||||
NOTSTARTED = "NOTSTARTED"
|
||||
RUNNING = "RUNNING"
|
||||
FINISHED = "FINISHED"
|
||||
FAILED = "FAILED"
|
||||
|
||||
|
||||
class BaseTask:
|
||||
def __init__(self, uuid: UUID) -> None:
|
||||
# constructor
|
||||
self.uuid: UUID = uuid
|
||||
self.log = logging.getLogger(__name__)
|
||||
self.procs: list[CmdState] = []
|
||||
self.failed: bool = False
|
||||
self.finished: bool = False
|
||||
self.procs: list[Command] = []
|
||||
self.status = TaskStatus.NOTSTARTED
|
||||
self.logs_lock = threading.Lock()
|
||||
self.error: Exception | None = None
|
||||
|
||||
def run(self) -> None:
|
||||
def _run(self) -> None:
|
||||
self.status = TaskStatus.RUNNING
|
||||
try:
|
||||
self.task_run()
|
||||
self.run()
|
||||
except Exception as e:
|
||||
# FIXME: fix exception handling here
|
||||
traceback.print_exception(*sys.exc_info())
|
||||
for proc in self.procs:
|
||||
proc.close_queue()
|
||||
self.failed = True
|
||||
self.error = e
|
||||
self.log.exception(e)
|
||||
finally:
|
||||
self.finished = True
|
||||
self.status = TaskStatus.FAILED
|
||||
else:
|
||||
self.status = TaskStatus.FINISHED
|
||||
|
||||
def task_run(self) -> None:
|
||||
def run(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
## TODO: If two clients are connected to the same task,
|
||||
def logs_iter(self) -> Iterator[str]:
|
||||
def log_lines(self) -> Iterator[str]:
|
||||
with self.logs_lock:
|
||||
for proc in self.procs:
|
||||
if self.finished:
|
||||
self.log.debug("log iter: Task is finished")
|
||||
break
|
||||
if self.status == TaskStatus.FINISHED:
|
||||
return
|
||||
# process has finished
|
||||
if proc.done:
|
||||
for line in proc.stderr:
|
||||
yield line + "\n"
|
||||
for line in proc.stdout:
|
||||
yield line + "\n"
|
||||
continue
|
||||
while True:
|
||||
out = proc._output
|
||||
line = out.get()
|
||||
if line is None:
|
||||
break
|
||||
yield line
|
||||
for line in proc.lines:
|
||||
yield line
|
||||
else:
|
||||
while line := proc._output.get():
|
||||
yield line
|
||||
|
||||
def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
|
||||
for i in range(num_cmds):
|
||||
cmd = CmdState(self.log)
|
||||
def register_commands(self, num_cmds: int) -> Iterator[Command]:
|
||||
for _ in range(num_cmds):
|
||||
cmd = Command(self.log)
|
||||
self.procs.append(cmd)
|
||||
|
||||
for cmd in self.procs:
|
||||
@@ -165,6 +157,6 @@ def create_task(task_type: Type[T], *args: Any) -> T:
|
||||
uuid = uuid4()
|
||||
|
||||
task = task_type(uuid, *args)
|
||||
threading.Thread(target=task._run).start()
|
||||
POOL[uuid] = task
|
||||
task.start()
|
||||
return task
|
||||
|
||||
Reference in New Issue
Block a user