Files
clan-core/pkgs/clan-cli/clan_cli/cmd.py

189 lines
4.9 KiB
Python

import logging
import os
import pty
import select
import shlex
import subprocess
import sys
import weakref
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from .custom_logger import get_caller
from .errors import ClanCmdError, CmdOut
# Module-level logger; `run` uses it to emit a debug line for every command it executes.
glog = logging.getLogger(__name__)
class Log(Enum):
    """
    Selects which of a command's output streams `run` echoes live to this
    process's own stdout/stderr. Captured output is returned regardless.
    """

    # Echo only the command's stderr.
    STDERR = 1
    # Echo only the command's stdout.
    STDOUT = 2
    # Echo both stdout and stderr.
    BOTH = 3
    # Echo nothing.
    NONE = 4
class TimeTable:
    """
    This class is used to store the time taken by each command
    and print it at the end of the program if env PERF=1 is set.
    """

    def __init__(self) -> None:
        # Maps a command string to the accumulated time spent running it.
        self.table: dict[str, timedelta] = {}
        # The finalizer callback must not hold a reference to `self` (a bound
        # method would), otherwise the instance could never be garbage
        # collected. Hand it only the dict, which does not point back at us.
        # `weakref.finalize` also fires at interpreter exit, which is what
        # produces the end-of-program report.
        weakref.finalize(self, TimeTable._print_table, self.table)

    @staticmethod
    def _print_table(table: dict[str, timedelta]) -> None:
        """Print `table`, slowest command first, but only when PERF=1 is set."""
        if os.getenv("PERF") != "1":
            return
        print("======== CMD TIMETABLE ========")
        # Sort the table by time in descending order
        slowest_first = sorted(table.items(), key=lambda item: item[1], reverse=True)
        for k, v in slowest_first:
            if v.total_seconds() > 1:
                # Commands that took longer than one second are printed in red
                print(f"\033[91mTook {v}s\033[0m for command: '{k}'")
            else:
                # Print in default color
                print(f"Took {v} for command: '{k}'")

    def table_print(self) -> None:
        """Print the accumulated timings to stdout if env PERF=1 is set."""
        TimeTable._print_table(self.table)

    def add(self, cmd: str, time: timedelta) -> None:
        """Add `time` to the total recorded for `cmd`, creating the entry if needed."""
        self.table[cmd] = self.table.get(cmd, timedelta()) + time
# Module-level timing table; `run` adds the elapsed time of every command it executes to it.
TIME_TABLE = TimeTable()
def run(
    cmd: list[str],
    *,
    input: bytes | None = None,  # noqa: A002
    env: dict[str, str] | None = None,
    cwd: Path | None = None,
    log: Log = Log.STDERR,
    check: bool = True,
    error_msg: str | None = None,
) -> CmdOut:
    """
    Run `cmd` in a new session attached to pseudo-terminals, capture its
    output and optionally echo it live.

    One pty is shared by the child's stdin and stdout, a second one carries
    its stderr.

    Args:
        cmd: Program and arguments; executed without a shell.
        input: Bytes written to the child's stdin pty after it has started.
            Empty or None means no input is sent.
        env: Environment for the child; None inherits the current one.
        cwd: Working directory for the child. None means the current working
            directory at call time.
        log: Which captured streams are echoed to this process's
            stdout/stderr while the command runs.
        check: Raise ClanCmdError when the command exits non-zero.
        error_msg: Stored as `msg` on the returned/raised CmdOut.

    Returns:
        CmdOut holding decoded stdout/stderr, cwd, the shell-quoted command,
        the return code and `error_msg`.

    Raises:
        ClanCmdError: if `check` is true and the return code is not 0.
        ValueError: if `input` could not be written in a single write.
    """
    # Resolve at call time: a `Path.cwd()` default in the signature would be
    # evaluated once at import and go stale after any later chdir.
    if cwd is None:
        cwd = Path.cwd()

    if input:
        glog.debug(
            f"""$: echo "{input.decode('utf-8', 'replace')}" | {shlex.join(cmd)} \nCaller: {get_caller()}"""
        )
    else:
        glog.debug(f"$: {shlex.join(cmd)} \nCaller: {get_caller()}")

    # Every fd we still own; whatever is left in here is closed in `finally`
    # so nothing leaks when Popen, the input write or the read loop raises.
    open_fds: list[int] = []
    stdout_chunks: list[bytes] = []
    stderr_chunks: list[bytes] = []
    try:
        # Create pseudo-terminals: one for stdin+stdout, one for stderr
        stdout_master_fd, stdout_slave_fd = pty.openpty()
        open_fds += [stdout_master_fd, stdout_slave_fd]
        stderr_master_fd, stderr_slave_fd = pty.openpty()
        open_fds += [stderr_master_fd, stderr_slave_fd]

        tstart = datetime.now()
        proc = subprocess.Popen(
            cmd,
            # Equivalent to preexec_fn=os.setsid, but safe in the presence of threads
            start_new_session=True,
            stdin=stdout_slave_fd,
            stdout=stdout_slave_fd,
            stderr=stderr_slave_fd,
            close_fds=True,
            env=env,
            cwd=str(cwd),
        )

        # The child holds its own copies of the slave ends; close ours so the
        # masters report end-of-stream once the child side is closed.
        for slave_fd in (stdout_slave_fd, stderr_slave_fd):
            open_fds.remove(slave_fd)
            os.close(slave_fd)

        if input:
            written_b = os.write(stdout_master_fd, input)
            if written_b != len(input):
                raise ValueError("Could not write all input to subprocess")

        # (master fd, capture list, stream to echo to or None), stdout first
        streams = (
            (
                stdout_master_fd,
                stdout_chunks,
                sys.stdout if log in (Log.STDOUT, Log.BOTH) else None,
            ),
            (
                stderr_master_fd,
                stderr_chunks,
                sys.stderr if log in (Log.STDERR, Log.BOTH) else None,
            ),
        )

        pending = [stdout_master_fd, stderr_master_fd]
        while pending:
            ready, _, _ = select.select(pending, [], [], 0.1)
            if not ready:  # timeout in select
                if proc.poll() is None:
                    continue
                # Process has exited and nothing more is readable
                break

            for fd, chunks, echo_to in streams:
                if fd not in ready:
                    continue
                try:
                    data = os.read(fd, 4096)
                except OSError:
                    # Reading a pty master can fail with OSError once the
                    # slave side is closed; treat that as end-of-stream.
                    data = b""
                if not data:
                    pending.remove(fd)
                    continue
                chunks.append(data)
                if echo_to is not None:
                    echo_to.buffer.write(data)
                    echo_to.flush()
    finally:
        for fd in open_fds:
            os.close(fd)

    # Wait for the subprocess to finish
    proc.wait()
    tend = datetime.now()
    TIME_TABLE.add(shlex.join(cmd), tend - tstart)

    cmd_out = CmdOut(
        stdout=b"".join(stdout_chunks).decode("utf-8", "replace"),
        stderr=b"".join(stderr_chunks).decode("utf-8", "replace"),
        cwd=cwd,
        command=shlex.join(cmd),
        returncode=proc.returncode,
        msg=error_msg,
    )

    if check and proc.returncode != 0:
        raise ClanCmdError(cmd_out)

    return cmd_out
def run_no_stdout(
    cmd: list[str],
    *,
    env: dict[str, str] | None = None,
    cwd: Path | None = None,
    log: Log = Log.STDERR,
    check: bool = True,
    error_msg: str | None = None,
) -> CmdOut:
    """
    Like run, but automatically suppresses live output if not in DEBUG log level.

    When the package's top-level logger is enabled for DEBUG, `log` is passed
    through to `run` unchanged. Otherwise `log` is overridden with Log.NONE,
    so neither stdout nor stderr of the command is echoed. The output is
    still captured in the returned CmdOut.

    `cwd` is forwarded to `run`; None means the current working directory at
    call time.
    """
    # Resolve here rather than in the signature so the default reflects the
    # working directory at call time, and so `run` always receives a Path.
    if cwd is None:
        cwd = Path.cwd()

    root_logger = logging.getLogger(__name__.split(".")[0])
    if not root_logger.isEnabledFor(logging.DEBUG):
        log = Log.NONE

    return run(cmd, env=env, cwd=cwd, log=log, check=check, error_msg=error_msg)