import logging
import threading
import time
import types
import uuid
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import IO, Any, Generic, ParamSpec, TypeVar

from clan_cli.errors import ClanError

log = logging.getLogger(__name__)


# Why did we create a custom AsyncRuntime instead of using asyncio?
#
# The AsyncRuntime class allows us to run functions in separate threads for
# asynchronous execution without requiring the use of the "async" keyword.
# With this approach, functions can gracefully handle cancellation by checking
# whether get_async_ctx().cancel is True.
#
# There was some resistance to using asyncio, partly due to challenges we faced
# when implementing it in our first web interface. Threads felt simpler and
# more familiar for our use case.
#
# That said, asyncio is generally more efficient for I/O-bound work because it
# multiplexes non-blocking I/O on a single thread, whereas Python's Global
# Interpreter Lock (GIL) limits what threads can gain in CPU-heavy workloads.
#
# Using threads works well for us because most of the time is spent waiting for
# commands or external processes to finish, rather than performing heavy
# computing tasks.
#
# Note: CPython ships a free-threaded build (experimental in 3.13, officially
# supported from 3.14) in which the GIL can be disabled to enable true
# parallelism. However, disabling the GIL introduces a 10-40% performance cost
# for Python code due to the overhead of additional locking.

# Define generics for return type and call signature
R = TypeVar("R")  # Return type of the callable
P = ParamSpec("P")  # Parameters of the callable


@dataclass
class AsyncResult(Generic[R]):
    """
    Outcome of a task: either the value the callable returned or the
    exception it raised.
    """

    _result: R | Exception

    @property
    def error(self) -> Exception | None:
        """
        Returns an error if the callable raised an exception.
        """
        if isinstance(self._result, Exception):
            return self._result
        return None

    @property
    def result(self) -> R:
        """
        Unwraps and returns the result if no exception occurred.
        Raises the exception otherwise.
        """
        if isinstance(self._result, Exception):
            raise self._result
        return self._result


@dataclass
class AsyncContext:
    """
    This class stores thread-local data.
    """

    prefix: str | None = None  # prefix for logging
    stdout: IO[bytes] | None = None  # stdout of subprocesses
    stderr: IO[bytes] | None = None  # stderr of subprocesses
    cancel: bool = False  # Used to signal cancellation of task


@dataclass
class AsyncOpts:
    """
    Options for the async_run function.
    """

    tid: str | None = None  # task name; async_run generates one when None
    check: bool = True  # whether check_all should inspect this task
    async_ctx: AsyncContext = field(default_factory=AsyncContext)


ASYNC_CTX_THREAD_LOCAL = threading.local()


def is_async_cancelled() -> bool:
    """
    Check if the current task has been cancelled.
    """
    return get_async_ctx().cancel


def get_async_ctx() -> AsyncContext:
    """
    Retrieve the current AsyncContext, creating a new one if none exists.
    """
    # No `global` statement needed: the name is only read, never rebound.
    if not hasattr(ASYNC_CTX_THREAD_LOCAL, "async_ctx"):
        ASYNC_CTX_THREAD_LOCAL.async_ctx = AsyncContext()
    return ASYNC_CTX_THREAD_LOCAL.async_ctx


def set_async_ctx(ctx: AsyncContext) -> None:
    """
    Install ctx as the AsyncContext of the calling thread.
    """
    ASYNC_CTX_THREAD_LOCAL.async_ctx = ctx


class AsyncThread(threading.Thread, Generic[P, R]):
    """
    Thread that runs one callable and records its outcome as an AsyncResult.
    """

    function: Callable[P, R]
    args: Any
    kwargs: Any
    result: AsyncResult[R] | None
    finished: bool
    condition: threading.Condition
    async_opts: AsyncOpts

    def __init__(
        self,
        async_opts: AsyncOpts,
        condition: threading.Condition,
        function: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """
        A threaded wrapper for running a function asynchronously.
        """
        super().__init__()
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.result: AsyncResult[R] | None = None  # Store the result or exception
        self.finished = False  # Set to True after the thread finishes execution
        self.condition = condition  # Shared condition variable
        self.async_opts = async_opts

    def run(self) -> None:
        """
        Run the function in a separate thread.

        The return value, or any raised Exception, is stored in self.result.
        """
        try:
            # Make the task's context visible to get_async_ctx() in this thread.
            set_async_ctx(self.async_opts.async_ctx)
            self.result = AsyncResult(_result=self.function(*self.args, **self.kwargs))
        except Exception as ex:
            self.result = AsyncResult(_result=ex)
        finally:
            self.finished = True
            # Acquire the condition lock before notifying
            with self.condition:
                # Notify waiting threads that this thread is done
                self.condition.notify_all()


@dataclass
class AsyncFuture(Generic[R]):
    """
    Handle to a task registered in an AsyncRuntime, identified by its name.
    """

    _tid: str
    _runtime: "AsyncRuntime"

    def wait(self) -> AsyncResult[R]:
        """
        Wait for the task to finish.

        Raises ClanError if the task is not registered in the runtime or if
        no result is available after the thread has been joined.
        """
        if self._tid not in self._runtime.tasks:
            msg = f"No task with the name '{self._tid}' exists."
            raise ClanError(msg)

        thread = self._runtime.tasks[self._tid]
        thread.join()
        result = self.get_result()
        if result is None:
            msg = f"Task '{self._tid}' unexpectedly returned None."
            raise ClanError(msg)
        return result

    def get_result(self) -> AsyncResult[R] | None:
        """
        Retrieve the result of a finished task and remove it from the task list.

        Returns None while the task is still running. Raises ClanError if the
        task is not registered or if it finished without storing a result.
        """
        if self._tid not in self._runtime.tasks:
            msg = f"No task with the name '{self._tid}' exists."
            raise ClanError(msg)

        thread = self._runtime.tasks[self._tid]

        if not thread.finished:
            return None

        # Remove the task after retrieving the result
        result = thread.result
        del self._runtime.tasks[self._tid]

        if result is None:
            msg = f"The result for task '{self._tid}' is unexpectedly None."
            raise ClanError(msg)
        return result


@dataclass
class AsyncRuntime:
    """
    Registry of named AsyncThread tasks sharing one condition variable.

    Used as a context manager, it flags every unfinished task as cancelled on
    exit.
    """

    tasks: dict[str, AsyncThread[Any, Any]] = field(default_factory=dict)
    condition: threading.Condition = field(default_factory=threading.Condition)

    def async_run(
        self,
        opts: AsyncOpts | None,
        function: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> AsyncFuture[R]:
        """
        Run the given function asynchronously in a thread with a specific name and arguments.
        The function's static typing is preserved.

        When opts is None default options are used; when opts.tid is None a
        random hex id is generated and stored on opts. Raises ClanError if a
        task with the same name is already registered.
        """
        if opts is None:
            opts = AsyncOpts()

        if opts.tid is None:
            opts.tid = uuid.uuid4().hex

        if opts.tid in self.tasks:
            msg = f"A task with the name '{opts.tid}' is already running."
            raise ClanError(msg)

        # Create and start the new AsyncThread
        thread = AsyncThread(opts, self.condition, function, *args, **kwargs)
        self.tasks[opts.tid] = thread
        thread.start()
        return AsyncFuture(opts.tid, self)

    def join_all(self) -> None:
        """
        Wait for all tasks to finish
        """
        with self.condition:
            # Re-check after every wake-up: any task may have sent the notification.
            while any(not task.finished for task in self.tasks.values()):
                self.condition.wait()  # Wait until a thread signals completion

    def check_all(self) -> None:
        """
        Check if there were any errors.

        Only finished tasks whose options have check=True are inspected. Each
        failure is logged (with traceback when DEBUG logging is enabled), and
        a ClanError is raised if at least one task failed.
        """
        err_count = 0

        for name, task in self.tasks.items():
            if not (task.finished and task.async_opts.check):
                continue

            assert task.result is not None
            error = task.result.error
            if error is None:
                # The task succeeded: it must not be logged or counted as failed.
                continue

            log.error(
                f"failed with error: {error}",
                extra={"command_prefix": name},
                # Only attach the traceback when debugging; None is the default.
                exc_info=error if log.isEnabledFor(logging.DEBUG) else None,
            )
            err_count += 1

        if err_count > 0:
            msg = f"{err_count} hosts failed with an error. Check the logs above"
            raise ClanError(msg)

    def __enter__(self) -> "AsyncRuntime":
        """
        Enter the runtime context related to this object.
        """
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """
        Exit the runtime context related to this object.
        Sets async_ctx.cancel to True to signal cancellation.
        """
        for name, task in self.tasks.items():
            if not task.finished:
                task.async_opts.async_ctx.cancel = True
                log.debug(f"Canceling task {name}")


# Example usage
if __name__ == "__main__":
    runtime = AsyncRuntime()

    def add(a: int, b: int) -> int:
        return a + b

    def concatenate(a: str, b: str) -> str:
        time.sleep(1)
        msg = "Hello World"
        raise ClanError(msg)

    with runtime:
        p1 = runtime.async_run(None, add, 1, 2)
        p2 = runtime.async_run(None, concatenate, "Hello ", "World")

        add_result = p1.wait()
        print(add_result.result)  # Output: 3
        concat_result = p2.wait()
        print(concat_result.error)  # Output: Hello World