189 lines
4.9 KiB
Python
189 lines
4.9 KiB
Python
import logging
|
|
import os
|
|
import pty
|
|
import select
|
|
import shlex
|
|
import subprocess
|
|
import sys
|
|
import weakref
|
|
from datetime import datetime, timedelta
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
|
|
from .custom_logger import get_caller
|
|
from .errors import ClanCmdError, CmdOut
|
|
|
|
# Module-level logger; ``run`` logs every command it starts at DEBUG level.
glog = logging.getLogger(__name__)


class Log(Enum):
    """Which of a child process's output streams ``run`` mirrors to ours."""

    STDERR = 1  # mirror only the child's stderr
    STDOUT = 2  # mirror only the child's stdout
    BOTH = 3  # mirror both stdout and stderr
    NONE = 4  # capture output silently, mirror nothing


class TimeTable:
    """
    Accumulate the time taken by each command and print it when the program
    ends, if the environment variable PERF=1 is set.
    """

    def __init__(self) -> None:
        # Shell-quoted command line -> total time spent running it.
        self.table: dict[str, timedelta] = {}
        # The finalizer callback must not reference ``self`` (a bound method
        # would), otherwise the object can never be garbage collected. Handing
        # it the table instead lets the summary print on collection or at exit.
        weakref.finalize(self, TimeTable._print_table, self.table)

    @staticmethod
    def _print_table(table: dict[str, timedelta]) -> None:
        """Print ``table`` slowest-first; does nothing unless PERF=1."""
        if os.getenv("PERF") != "1":
            return
        print("======== CMD TIMETABLE ========")
        by_time = sorted(table.items(), key=lambda item: item[1], reverse=True)
        for cmd, took in by_time:
            if took.total_seconds() > 1:
                # Highlight commands that took longer than a second in red.
                # str(timedelta) is already "H:MM:SS[.ffffff]", so no unit suffix.
                print(f"\033[91mTook {took}\033[0m for command: '{cmd}'")
            else:
                print(f"Took {took} for command: '{cmd}'")

    def table_print(self) -> None:
        """Print the accumulated timings now; does nothing unless PERF=1."""
        self._print_table(self.table)

    def add(self, cmd: str, time: timedelta) -> None:
        """Add ``time`` to the running total recorded for ``cmd``."""
        self.table[cmd] = self.table.get(cmd, timedelta()) + time


# Process-wide table that ``run`` records every command's duration into.
TIME_TABLE = TimeTable()


def _forward_pty_output(
    proc: "subprocess.Popen[bytes]",
    stdout_fd: int,
    stderr_fd: int,
    log: Log,
) -> tuple[bytes, bytes]:
    """
    Collect output from the stdout/stderr pty masters of ``proc``.

    Reads until both masters report EOF/error, or until ``proc`` has exited
    and nothing was readable for 0.1s. Each chunk is mirrored to this
    process's stdout/stderr as selected by ``log``.

    :return: the collected ``(stdout, stderr)`` bytes.
    """
    chunks: dict[int, list[bytes]] = {stdout_fd: [], stderr_fd: []}
    # Which of our own streams each pty is mirrored to (None = capture only).
    mirror = {
        stdout_fd: sys.stdout if log in (Log.STDOUT, Log.BOTH) else None,
        stderr_fd: sys.stderr if log in (Log.STDERR, Log.BOTH) else None,
    }
    open_fds = [stdout_fd, stderr_fd]

    while open_fds:
        readable, _, _ = select.select(open_fds, [], [], 0.1)
        if not readable:
            if proc.poll() is None:
                continue
            # The process has exited and nothing is pending. Stop even if
            # another process still holds the pty's slave side open.
            break

        for fd in (stdout_fd, stderr_fd):
            if fd not in readable:
                continue
            try:
                data = os.read(fd, 4096)
            except OSError:
                # Reading a pty master after its slave side is fully closed
                # raises OSError (EIO on Linux); treat that like EOF.
                data = b""
            if not data:
                open_fds.remove(fd)
                continue
            chunks[fd].append(data)
            out = mirror[fd]
            if out is not None:
                out.buffer.write(data)
                out.flush()

    return b"".join(chunks[stdout_fd]), b"".join(chunks[stderr_fd])


def run(
    cmd: list[str],
    *,
    input: bytes | None = None,  # noqa: A002
    env: dict[str, str] | None = None,
    cwd: Path | None = None,
    log: Log = Log.STDERR,
    check: bool = True,
    error_msg: str | None = None,
) -> CmdOut:
    """
    Run ``cmd`` attached to pseudo-terminals and capture its output.

    :param cmd: argv of the command; it is executed directly, without a shell.
    :param input: bytes written to the pty that serves as the child's stdin
        (and stdout) before output is collected.
    :param env: environment for the child; ``None`` inherits ours.
    :param cwd: working directory; ``None`` means the current directory at
        the time of the call.
    :param log: which of the child's streams to mirror to our stdout/stderr.
    :param check: raise ``ClanCmdError`` when the exit code is non-zero.
    :param error_msg: message stored on the resulting ``CmdOut``.
    :return: ``CmdOut`` with decoded stdout/stderr, cwd, command and exit code.
    :raises ClanCmdError: if ``check`` is true and the command exits non-zero.
    :raises ValueError: if ``input`` could not be written in a single write.
    """
    if cwd is None:
        # Resolve per call: a ``Path.cwd()`` default is evaluated once at
        # import time and goes stale if the process later changes directory.
        cwd = Path.cwd()
    cmd_str = shlex.join(cmd)

    if input:
        glog.debug(
            f"""$: echo "{input.decode('utf-8', 'replace')}" | {cmd_str} \nCaller: {get_caller()}"""
        )
    else:
        glog.debug(f"$: {cmd_str} \nCaller: {get_caller()}")

    # Every fd we still own; all of them are closed in ``finally``.
    owned_fds: list[int] = []
    try:
        # One pty serves as the child's stdin *and* stdout, a second one as
        # its stderr, so the two output streams can be told apart.
        stdout_master_fd, stdout_slave_fd = pty.openpty()
        owned_fds += [stdout_master_fd, stdout_slave_fd]
        stderr_master_fd, stderr_slave_fd = pty.openpty()
        owned_fds += [stderr_master_fd, stderr_slave_fd]

        tstart = datetime.now()
        proc = subprocess.Popen(
            cmd,
            # Equivalent to preexec_fn=os.setsid, but safe in threaded parents.
            start_new_session=True,
            stdin=stdout_slave_fd,
            stdout=stdout_slave_fd,
            stderr=stderr_slave_fd,
            close_fds=True,
            env=env,
            cwd=str(cwd),
        )

        # The child holds its own copies; drop ours so the masters can report
        # EOF/EIO once the child side is closed.
        for fd in (stdout_slave_fd, stderr_slave_fd):
            owned_fds.remove(fd)
            os.close(fd)

        try:
            if input:
                # NOTE(review): the pty is left in its default terminal mode,
                # so the input is likely echoed back into the captured stdout
                # and no EOF follows it; confirm callers expect this.
                written_b = os.write(stdout_master_fd, input)
                if written_b != len(input):
                    raise ValueError("Could not write all input to subprocess")
            stdout_buf, stderr_buf = _forward_pty_output(
                proc, stdout_master_fd, stderr_master_fd, log
            )
        except BaseException:
            # The child runs in a new session, so a terminal-generated SIGINT
            # does not reach it; kill it rather than leave it running.
            proc.kill()
            proc.wait()
            raise
    finally:
        for fd in owned_fds:
            os.close(fd)

    # Wait for the subprocess to finish and record how long it took.
    proc.wait()
    TIME_TABLE.add(cmd_str, datetime.now() - tstart)

    cmd_out = CmdOut(
        stdout=stdout_buf.decode("utf-8", "replace"),
        stderr=stderr_buf.decode("utf-8", "replace"),
        cwd=cwd,
        command=cmd_str,
        returncode=proc.returncode,
        msg=error_msg,
    )

    if check and proc.returncode != 0:
        raise ClanCmdError(cmd_out)

    return cmd_out


def run_no_stdout(
    cmd: list[str],
    *,
    env: dict[str, str] | None = None,
    cwd: Path | None = None,
    log: Log = Log.STDERR,
    check: bool = True,
    error_msg: str | None = None,
) -> CmdOut:
    """
    Like ``run``, but mirror no command output unless DEBUG logging is enabled.

    If the package's top-level logger (the first dotted component of this
    module's name) is enabled for DEBUG, ``log`` is passed through unchanged,
    which with the default ``Log.STDERR`` mirrors only stderr. Otherwise
    ``Log.NONE`` is used, so output is captured but not shown.

    :param cwd: working directory; ``None`` means the current directory at the
        time of the call. It is forwarded to ``run``.
    :return: the ``CmdOut`` produced by ``run``.
    :raises ClanCmdError: propagated from ``run`` when ``check`` is true.
    """
    if cwd is None:
        # Resolve per call rather than freezing the directory at import time.
        cwd = Path.cwd()
    package_logger = logging.getLogger(__name__.split(".")[0])
    if not package_logger.isEnabledFor(logging.DEBUG):
        log = Log.NONE
    return run(cmd, env=env, cwd=cwd, log=log, check=check, error_msg=error_msg)