diff --git a/checks/lib/container-driver/package.nix b/checks/lib/container-driver/package.nix index cbb47e2ee..d4dad4aa9 100644 --- a/checks/lib/container-driver/package.nix +++ b/checks/lib/container-driver/package.nix @@ -5,6 +5,8 @@ setuptools, util-linux, systemd, + colorama, + junit-xml, }: buildPythonApplication { pname = "test-driver"; @@ -12,6 +14,8 @@ buildPythonApplication { propagatedBuildInputs = [ util-linux systemd + colorama + junit-xml ] ++ extraPythonPackages python3Packages; nativeBuildInputs = [ setuptools ]; format = "pyproject"; diff --git a/checks/lib/container-driver/test_driver/__init__.py b/checks/lib/container-driver/test_driver/__init__.py index e45cf3db2..a74fab7b4 100644 --- a/checks/lib/container-driver/test_driver/__init__.py +++ b/checks/lib/container-driver/test_driver/__init__.py @@ -5,10 +5,13 @@ import subprocess import time import types from collections.abc import Callable +from contextlib import _GeneratorContextManager from pathlib import Path from tempfile import TemporaryDirectory from typing import Any +from .logger import AbstractLogger, CompositeLogger, TerminalLogger + class Error(Exception): pass @@ -42,12 +45,20 @@ def retry(fn: Callable, timeout: int = 900) -> None: class Machine: - def __init__(self, name: str, toplevel: Path, rootdir: Path, out_dir: str) -> None: + def __init__( + self, + name: str, + toplevel: Path, + logger: AbstractLogger, + rootdir: Path, + out_dir: str, + ) -> None: self.name = name self.toplevel = toplevel self.out_dir = out_dir self.process: subprocess.Popen | None = None self.rootdir: Path = rootdir + self.logger = logger def start(self) -> None: prepare_machine_root(self.name, self.rootdir) @@ -187,6 +198,15 @@ class Machine: ) return proc + def nested( + self, msg: str, attrs: dict[str, str] | None = None + ) -> _GeneratorContextManager: + if attrs is None: + attrs = {} + my_attrs = {"machine": self.name} + my_attrs.update(attrs) + return self.logger.nested(msg, my_attrs) + def systemctl(self, q: str) -> subprocess.CompletedProcess: """ Runs `systemctl` commands with optional support for @@ -203,6 +223,25 @@ class Machine: """ return self.execute(f"systemctl {q}") + def wait_until_succeeds(self, command: str, timeout: int = 900) -> str: + """ + Repeat a shell command with 1-second intervals until it succeeds. + Has a default timeout of 900 seconds which can be modified, e.g. + `wait_until_succeeds(cmd, timeout=10)`. See `execute` for details on + command execution. + Throws an exception on timeout. + """ + output = "" + + def check_success(_: Any) -> bool: + nonlocal output + result = self.execute(command, timeout=timeout) + return result.returncode == 0 + + with self.nested(f"waiting for success: {command}"): + retry(check_success, timeout) + return output + def wait_for_unit(self, unit: str, timeout: int = 900) -> None: """ Wait for a systemd unit to get into "active" state. @@ -260,10 +299,19 @@ def setup_filesystems() -> None: class Driver: - def __init__(self, containers: list[Path], testscript: str, out_dir: str) -> None: + logger: AbstractLogger + + def __init__( + self, + containers: list[Path], + logger: AbstractLogger, + testscript: str, + out_dir: str, + ) -> None: self.containers = containers self.testscript = testscript self.out_dir = out_dir + self.logger = logger setup_filesystems() self.tempdir = TemporaryDirectory() @@ -282,6 +330,7 @@ class Driver: toplevel=container, rootdir=tempdir_path / name, out_dir=self.out_dir, + logger=self.logger, ) ) @@ -367,9 +416,11 @@ def main() -> None: type=writeable_dir, ) args = arg_parser.parse_args() + logger = CompositeLogger([TerminalLogger()]) with Driver( - args.containers, - args.test_script.read_text(), - args.output_directory.resolve(), + containers=args.containers, + testscript=args.test_script.read_text(), + out_dir=args.output_directory.resolve(), + logger=logger, ) as driver: driver.run_tests() diff --git a/checks/lib/container-driver/test_driver/logger.py b/checks/lib/container-driver/test_driver/logger.py new file mode 100644 index 000000000..309801b84 --- /dev/null +++ b/checks/lib/container-driver/test_driver/logger.py @@ -0,0 +1,335 @@ +import atexit +import codecs +import os +import sys +import time +import unicodedata +from abc import ABC, abstractmethod +from collections.abc import Iterator +from contextlib import ExitStack, contextmanager +from pathlib import Path +from queue import Empty, Queue +from typing import Any +from xml.sax.saxutils import XMLGenerator +from xml.sax.xmlreader import AttributesImpl + +from colorama import Fore, Style +from junit_xml import TestCase, TestSuite + + +class AbstractLogger(ABC): + @abstractmethod + def log(self, message: str, attributes: dict[str, str] | None = None) -> None: + pass + + @abstractmethod + @contextmanager + def subtest( + self, name: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + pass + + @abstractmethod + @contextmanager + def nested( + self, message: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + pass + + @abstractmethod + def info(self, *args: Any, **kwargs: Any) -> None: # type: ignore + pass + + @abstractmethod + def warning(self, *args: Any, **kwargs: Any) -> None: # type: ignore + pass + + @abstractmethod + def error(self, *args: Any, **kwargs: Any) -> None: # type: ignore + pass + + @abstractmethod + def log_serial(self, message: str, machine: str) -> None: + pass + + @abstractmethod + def print_serial_logs(self, enable: bool) -> None: + pass + + +class JunitXMLLogger(AbstractLogger): + class TestCaseState: + def __init__(self) -> None: + self.stdout = "" + self.stderr = "" + self.failure = False + + def __init__(self, outfile: Path) -> None: + self.tests: dict[str, JunitXMLLogger.TestCaseState] = { + "main": self.TestCaseState() + } + self.currentSubtest = "main" + self.outfile: Path = outfile + self._print_serial_logs = True + atexit.register(self.close) + + def log(self, message: str, attributes: dict[str, str] | None = None) -> None: + self.tests[self.currentSubtest].stdout += message + os.linesep + + @contextmanager + def subtest( + self, name: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + old_test = self.currentSubtest + self.tests.setdefault(name, self.TestCaseState()) + self.currentSubtest = name + + yield + + self.currentSubtest = old_test + + @contextmanager + def nested( + self, message: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + self.log(message) + yield + + def info(self, *args: Any, **kwargs: Any) -> None: + self.tests[self.currentSubtest].stdout += args[0] + os.linesep + + def warning(self, *args: Any, **kwargs: Any) -> None: + self.tests[self.currentSubtest].stdout += args[0] + os.linesep + + def error(self, *args: Any, **kwargs: Any) -> None: + self.tests[self.currentSubtest].stderr += args[0] + os.linesep + self.tests[self.currentSubtest].failure = True + + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + + self.log(f"{machine} # {message}") + + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + + def close(self) -> None: + with Path.open(self.outfile, "w") as f: + test_cases = [] + for name, test_case_state in self.tests.items(): + tc = TestCase( + name, + stdout=test_case_state.stdout, + stderr=test_case_state.stderr, + ) + if test_case_state.failure: + tc.add_failure_info("test case failed") + + test_cases.append(tc) + ts = TestSuite("NixOS integration test", test_cases) + f.write(TestSuite.to_xml_string([ts])) + + +class CompositeLogger(AbstractLogger): + def __init__(self, logger_list: list[AbstractLogger]) -> None: + self.logger_list = logger_list + + def add_logger(self, logger: AbstractLogger) -> None: + self.logger_list.append(logger) + + def log(self, message: str, attributes: dict[str, str] | None = None) -> None: + for logger in self.logger_list: + logger.log(message, attributes) + + @contextmanager + def subtest( + self, name: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + with ExitStack() as stack: + for logger in self.logger_list: + stack.enter_context(logger.subtest(name, attributes)) + yield + + @contextmanager + def nested( + self, message: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + with ExitStack() as stack: + for logger in self.logger_list: + stack.enter_context(logger.nested(message, attributes)) + yield + + def info(self, *args: Any, **kwargs: Any) -> None: # type: ignore + for logger in self.logger_list: + logger.info(*args, **kwargs) + + def warning(self, *args: Any, **kwargs: Any) -> None: # type: ignore + for logger in self.logger_list: + logger.warning(*args, **kwargs) + + def error(self, *args: Any, **kwargs: Any) -> None: # type: ignore + for logger in self.logger_list: + logger.error(*args, **kwargs) + sys.exit(1) + + def print_serial_logs(self, enable: bool) -> None: + for logger in self.logger_list: + logger.print_serial_logs(enable) + + def log_serial(self, message: str, machine: str) -> None: + for logger in self.logger_list: + logger.log_serial(message, machine) + + +class TerminalLogger(AbstractLogger): + def __init__(self) -> None: + self._print_serial_logs = True + + def maybe_prefix(self, message: str, attributes: dict[str, str] | None) -> str: + if attributes and "machine" in attributes: + return f"{attributes['machine']}: {message}" + return message + + @staticmethod + def _eprint(*args: object, **kwargs: Any) -> None: + print(*args, file=sys.stderr, **kwargs) + + def log(self, message: str, attributes: dict[str, str] | None = None) -> None: + self._eprint(self.maybe_prefix(message, attributes)) + + @contextmanager + def subtest( + self, name: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + with self.nested("subtest: " + name, attributes): + yield + + @contextmanager + def nested( + self, message: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + self._eprint( + self.maybe_prefix( + Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes + ) + ) + + tic = time.time() + yield + toc = time.time() + self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)") + + def info(self, *args: Any, **kwargs: Any) -> None: # type: ignore + self.log(*args, **kwargs) + + def warning(self, *args: Any, **kwargs: Any) -> None: # type: ignore + self.log(*args, **kwargs) + + def error(self, *args: Any, **kwargs: Any) -> None: # type: ignore + self.log(*args, **kwargs) + + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + + self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL) + + +class XMLLogger(AbstractLogger): + def __init__(self, outfile: str) -> None: + self.logfile_handle = codecs.open(outfile, "wb") + self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8") + self.queue: Queue[dict[str, str]] = Queue() + + self._print_serial_logs = True + + self.xml.startDocument() + self.xml.startElement("logfile", attrs=AttributesImpl({})) + + def close(self) -> None: + self.xml.endElement("logfile") + self.xml.endDocument() + self.logfile_handle.close() + + def sanitise(self, message: str) -> str: + return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C") + + def maybe_prefix( + self, message: str, attributes: dict[str, str] | None = None + ) -> str: + if attributes and "machine" in attributes: + return f"{attributes['machine']}: {message}" + return message + + def log_line(self, message: str, attributes: dict[str, str]) -> None: + self.xml.startElement("line", attrs=AttributesImpl(attributes)) + self.xml.characters(message) + self.xml.endElement("line") + + def info(self, *args: Any, **kwargs: Any) -> None: # type: ignore + self.log(*args, **kwargs) + + def warning(self, *args: Any, **kwargs: Any) -> None: # type: ignore + self.log(*args, **kwargs) + + def error(self, *args: Any, **kwargs: Any) -> None: # type: ignore + self.log(*args, **kwargs) + + def log(self, message: str, attributes: dict[str, str] | None = None) -> None: + if attributes is None: + attributes = {} + self.drain_log_queue() + self.log_line(message, attributes) + + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + + self.enqueue({"msg": message, "machine": machine, "type": "serial"}) + + def enqueue(self, item: dict[str, str]) -> None: + self.queue.put(item) + + def drain_log_queue(self) -> None: + try: + while True: + item = self.queue.get_nowait() + msg = self.sanitise(item["msg"]) + del item["msg"] + self.log_line(msg, item) + except Empty: + pass + + @contextmanager + def subtest( + self, name: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + with self.nested("subtest: " + name, attributes): + yield + + @contextmanager + def nested( + self, message: str, attributes: dict[str, str] | None = None + ) -> Iterator[None]: + if attributes is None: + attributes = {} + self.xml.startElement("nest", attrs=AttributesImpl({})) + self.xml.startElement("head", attrs=AttributesImpl(attributes)) + self.xml.characters(message) + self.xml.endElement("head") + + tic = time.time() + self.drain_log_queue() + yield + self.drain_log_queue() + toc = time.time() + self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)") + + self.xml.endElement("nest")