Files
clan-core/pkgs/clan-cli/clan_cli/task_manager.py
2023-10-05 17:37:33 +02:00

168 lines
4.9 KiB
Python

import codecs
import io
import logging
import os
import queue
import select
import shlex
import subprocess
import sys
import threading
import traceback
from enum import Enum
from typing import Any, Iterator, Optional, Type, TypeVar
from uuid import UUID, uuid4

from .errors import ClanError
class Command:
    """Run one external command and capture its output line by line.

    Every captured line is kept in ``stdout`` or ``stderr`` (depending on the
    stream it came from) and is also put on ``_output`` in arrival order, so a
    reader can stream it while the command is still running.
    ``close_queue()`` terminates that stream with a ``None`` sentinel.
    """

    def __init__(self, log: logging.Logger) -> None:
        self.log: logging.Logger = log
        # The child process; None until run() starts it.
        self.p: subprocess.Popen | None = None
        # Lines from both streams in arrival order, then a None sentinel.
        self._output: queue.SimpleQueue = queue.SimpleQueue()
        self.returncode: int | None = None
        self.done: bool = False
        # Set to True by run(); initialised here so the attribute always exists.
        self.running: bool = False
        self.stdout: list[str] = []
        self.stderr: list[str] = []

    def close_queue(self) -> None:
        """Copy the exit status, end the output stream and mark this command done."""
        if self.p is not None:
            self.returncode = self.p.returncode
        self._output.put(None)
        self.done = True

    def run(self, cmd: list[str], env: Optional[dict[str, str]] = None) -> None:
        """Run ``cmd`` to completion, capturing its output as it is produced.

        Each line (newline kept; CRLF and lone CR normalised to LF, as in a
        text-mode pipe) is echoed to this process's stdout, appended to
        ``self.stdout`` or ``self.stderr`` and put on the output queue. A
        final line without a trailing newline is emitted as-is at EOF.

        Args:
            cmd: Program and arguments; executed directly, without a shell.
            env: Environment for the child; ``None`` inherits the current one.

        Raises:
            ClanError: If the command exits with a non-zero status.
        """
        self.running = True
        self.log.debug(f"Running command: {shlex.join(cmd)}")
        # Binary pipes: output is read in raw chunks with os.read() and decoded
        # here, so a partial line can never block the read loop.
        self.p = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )
        assert self.p.stdout is not None and self.p.stderr is not None
        try:
            self._pump(
                cmd[0],
                {
                    self.p.stdout.fileno(): ("stdout", self.stdout),
                    self.p.stderr.fileno(): ("stderr", self.stderr),
                },
            )
            # Reap the child and populate returncode.
            self.p.wait()
        finally:
            # Release the pipe file descriptors even if reading failed.
            self.p.stdout.close()
            self.p.stderr.close()
        if self.p.returncode != 0:
            raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
        self.log.debug("Successfully ran command")

    def _pump(self, prog: str, streams: dict[int, tuple[str, list[str]]]) -> None:
        """Read the child's pipes until EOF, or until it exited and they went idle.

        ``streams`` maps each pipe file descriptor to its stream name and the
        list that collects its lines.
        """
        assert self.p is not None
        decoders = {
            fd: io.IncrementalNewlineDecoder(
                codecs.getincrementaldecoder("utf-8")(errors="replace"),
                translate=True,
            )
            for fd in streams
        }
        partial = {fd: "" for fd in streams}

        def feed(fd: int, data: bytes, final: bool) -> None:
            # Decode ``data``, emit every completed line and keep the remainder.
            # With ``final`` the decoder is flushed and any remainder emitted.
            name, sink = streams[fd]
            text = partial[fd] + decoders[fd].decode(data, final)
            *lines, rest = text.split("\n")
            for line in lines:
                self._record(prog, name, sink, line + "\n")
            if final and rest:
                self._record(prog, name, sink, rest)
                rest = ""
            partial[fd] = rest

        open_fds = set(streams)
        while open_fds:
            # Sample the exit status *before* select(): whatever the child wrote
            # before exiting is already buffered in the pipe, so once it has
            # exited, a zero-timeout select() that finds nothing means we are done.
            exited = self.p.poll() is not None
            ready, _, _ = select.select(list(open_fds), [], [], 0 if exited else 0.1)
            if not ready:
                if exited:
                    # A grandchild may keep the pipe open; don't wait for it.
                    break
                continue
            for fd in ready:
                chunk = os.read(fd, 65536)
                if chunk:
                    feed(fd, chunk, False)
                else:  # EOF on this pipe
                    feed(fd, b"", True)
                    open_fds.discard(fd)
        # Flush streams we stopped reading before they reached EOF.
        for fd in open_fds:
            feed(fd, b"", True)

    def _record(self, prog: str, name: str, sink: list[str], line: str) -> None:
        """Echo one captured line and store it in ``sink`` and the output queue."""
        # ``line`` keeps its newline for storage; strip it for the echo so
        # print() does not emit an extra blank line.
        text = line.rstrip("\n")
        print(f"[{prog}] {name}: {text}")
        sink.append(line)
        self._output.put(line)
class TaskStatus(str, Enum):
    """Lifecycle state of a ``BaseTask``.

    Members also subclass ``str``, so each compares equal to its plain
    string value (e.g. ``TaskStatus.RUNNING == "RUNNING"``).
    """

    # Initial state assigned when the task object is constructed.
    NOTSTARTED = "NOTSTARTED"
    # Assigned when the task's worker begins executing run().
    RUNNING = "RUNNING"
    # run() returned without raising.
    FINISHED = "FINISHED"
    # run() raised; the exception is kept on the task.
    FAILED = "FAILED"
class BaseTask:
    """Base class for background jobs made of one or more ``Command`` steps.

    Subclasses override ``run()``. ``_run()`` is the thread entry point: it
    wraps ``run()`` with status bookkeeping and, no matter how ``run()``
    ends, closes every command's output queue.
    """

    def __init__(self, uuid: UUID, num_cmds: int) -> None:
        self.uuid: UUID = uuid
        self.log = logging.getLogger(__name__)
        # One Command per step; all of them share this task's logger.
        self.procs: list[Command] = [Command(self.log) for _ in range(num_cmds)]
        self.status = TaskStatus.NOTSTARTED
        self.logs_lock = threading.Lock()
        # The exception raised by run(), if any.
        self.error: Exception | None = None

    def _run(self) -> None:
        """Execute ``run()`` and record the outcome in ``status``/``error``."""
        self.status = TaskStatus.RUNNING
        try:
            self.run()
        except Exception as exc:
            # FIXME: exception handling needs work; the traceback is currently
            # reported twice (printed to stderr and sent to the logger).
            traceback.print_exc()
            self.error = exc
            self.log.exception(exc)
            self.status = TaskStatus.FAILED
        else:
            self.status = TaskStatus.FINISHED
        finally:
            # Unblock any reader waiting on a command's queue.
            for step in self.procs:
                step.close_queue()

    def run(self) -> None:
        """Perform the task's work; subclasses must override this."""
        raise NotImplementedError

    # TODO: define what should happen when two clients are connected to the
    # same task.
    # NOTE(review): queue reads are destructive, and the lock below is held
    # for as long as a reader iterates, so readers are served strictly one
    # after another.
    def log_lines(self) -> Iterator[str]:
        """Yield the output lines of every command, in command order.

        A command that is already done replays its stored stdout lines, then
        its stderr lines. Otherwise lines are taken from its output queue
        until the end-of-stream sentinel. The status is checked before each
        command, and iteration stops as soon as it is FINISHED.
        NOTE(review): a FINISHED task therefore yields nothing, even though
        its commands still hold their output - confirm this is intended.
        """
        with self.logs_lock:
            for step in self.procs:
                if self.status == TaskStatus.FINISHED:
                    return
                if not step.done:
                    # Blocks until output arrives; close_queue()'s None ends it.
                    while line := step._output.get():
                        yield line
                    continue
                yield from step.stdout
                yield from step.stderr

    def commands(self) -> Iterator[Command]:
        """Iterate over this task's Command objects."""
        for step in self.procs:
            yield step
# TODO: Add tests covering concurrent access to the task pool.
class TaskPool:
    """Lock-protected registry mapping task UUIDs to task objects."""

    def __init__(self) -> None:
        self.lock: threading.RLock = threading.RLock()
        self.pool: dict[UUID, BaseTask] = {}

    def __getitem__(self, uuid: UUID) -> "BaseTask":
        """Return the task registered under ``uuid``.

        Raises:
            KeyError: If no task is registered under ``uuid``.
        """
        with self.lock:
            return self.pool[uuid]

    def __setitem__(self, uuid: UUID, task: "BaseTask") -> None:
        """Register ``task`` under ``uuid``; each UUID may be used only once.

        Raises:
            TypeError: If ``uuid`` is not a ``UUID`` instance.
            KeyError: If a task is already registered under ``uuid``.
        """
        # Validate the key before touching the pool. isinstance() also
        # accepts UUID subclasses, which the old exact-type check rejected.
        if not isinstance(uuid, UUID):
            raise TypeError("uuid must be of type UUID")
        with self.lock:
            if uuid in self.pool:
                raise KeyError(f"Task with uuid {uuid} already exists")
            self.pool[uuid] = task
# Process-wide task registry; create_task() registers every new task here.
POOL: TaskPool = TaskPool()


def get_task(uuid: UUID) -> BaseTask:
    """Look up a task in the global pool by its UUID.

    Raises:
        KeyError: If no task with that UUID has been registered.
    """
    # Only reading the module global, so no ``global`` statement is needed.
    found = POOL[uuid]
    return found
# Lets create_task() be typed as returning the concrete task subclass it builds.
T = TypeVar("T", bound="BaseTask")


def create_task(task_type: Type[T], *args: Any) -> T:
    """Build a task, register it in the global pool and start it.

    The task is constructed as ``task_type(<new uuid4>, *args)``. Its
    ``_run`` method is launched on a freshly started thread before this
    function returns.

    Returns:
        The newly created task instance.
    """
    task_id = uuid4()
    new_task = task_type(task_id, *args)
    POOL[task_id] = new_task
    worker = threading.Thread(target=new_task._run)
    worker.start()
    return new_task