313 lines
9.3 KiB
Python
313 lines
9.3 KiB
Python
import logging
|
|
import threading
|
|
import time
|
|
import types
|
|
import uuid
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from typing import IO, Any, Generic, ParamSpec, TypeVar
|
|
|
|
from clan_cli.errors import ClanError
|
|
|
|
# Module-level logger, named after this module's dotted import path.
log = logging.getLogger(name=__name__)
|
|
|
|
# Why did we create a custom AsyncRuntime instead of using asyncio?
|
|
#
|
|
# The AsyncRuntime class allows us to run functions in separate threads for asynchronous
|
|
# execution without requiring the use of the "async" keyword. By using this approach,
|
|
# functions can gracefully handle cancellation by checking if get_async_ctx().cancel is True.
|
|
#
|
|
# There was some resistance to using asyncio, partly due to challenges we faced when
|
|
# implementing it in our first web interface. Threads felt simpler and more familiar
|
|
# for our use case.
|
|
#
|
|
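# That said, asyncio is generally more efficient: it multiplexes non-blocking I/O on a
# single thread, avoiding the per-thread overhead and Global Interpreter Lock (GIL)
# contention that come with running many threads. Note that neither approach speeds
# up CPU-heavy Python code, because the GIL lets only one thread run Python bytecode
# at a time.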
|
|
#
|
|
# Using threads works well for us because most of the time is spent waiting for commands
|
|
# or external processes to finish, rather than performing heavy computing tasks.
|
|
#
|
|
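# Note: CPython 3.13 added an experimental free-threaded build (PEP 703) in which the
# GIL is disabled, enabling true parallelism; Python 3.14 makes that build officially
# supported. However, running without the GIL costs single-threaded performance,
# partly due to the overhead of additional locking (reported figures range from a few
# percent up to roughly 40%, depending on the Python version and workload).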
|
|
|
|
# Define generics for return type and call signature
|
|
R = TypeVar("R") # Return type of the callable
|
|
P = ParamSpec("P") # Parameters of the callable
|
|
|
|
|
|
@dataclass
class AsyncResult(Generic[R]):
    """
    Outcome of a task: either the callable's return value or the exception it raised.
    """

    _result: R | Exception  # return value, or the exception raised by the callable

    @property
    def error(self) -> Exception | None:
        """
        The exception raised by the callable, or None if it returned normally.
        """
        outcome = self._result
        return outcome if isinstance(outcome, Exception) else None

    @property
    def result(self) -> R:
        """
        The callable's return value.

        If the callable raised instead, that stored exception is re-raised here.
        """
        outcome = self._result
        if not isinstance(outcome, Exception):
            return outcome
        raise outcome
|
|
|
|
|
|
@dataclass
|
|
class AsyncContext:
|
|
"""
|
|
This class stores thread-local data.
|
|
"""
|
|
|
|
prefix: str | None = None # prefix for logging
|
|
stdout: IO[bytes] | None = None # stdout of subprocesses
|
|
stderr: IO[bytes] | None = None # stderr of subprocesses
|
|
cancel: bool = False # Used to signal cancellation of task
|
|
|
|
|
|
@dataclass
class AsyncOpts:
    """
    Per-task options accepted by AsyncRuntime.async_run.
    """

    # Task id; async_run fills in a random hex id when this is left as None.
    tid: str | None = None
    # Whether AsyncRuntime.check_all should report a failure of this task.
    check: bool = True
    # Context installed as the thread-local AsyncContext while the task runs.
    async_ctx: AsyncContext = field(default_factory=AsyncContext)
|
|
|
|
|
|
# Thread-local holder; each thread's AsyncContext is kept in its ``async_ctx`` attribute.
ASYNC_CTX_THREAD_LOCAL: threading.local = threading.local()
|
|
|
|
|
|
def is_async_cancelled() -> bool:
    """
    Report whether cancellation was requested for the calling thread's task.

    Equivalent to ``get_async_ctx().cancel``.
    """
    ctx = get_async_ctx()
    return ctx.cancel
|
|
|
|
|
|
def get_async_ctx() -> AsyncContext:
    """
    Return the AsyncContext bound to the calling thread.

    If the thread has no context yet, a default AsyncContext is created, stored
    in thread-local storage and returned, so later calls on the same thread
    return that same object.
    """
    # No ``global`` declaration: we only read/set an attribute on the
    # thread-local object and never rebind the module-level name.
    try:
        return ASYNC_CTX_THREAD_LOCAL.async_ctx
    except AttributeError:
        ctx = AsyncContext()
        ASYNC_CTX_THREAD_LOCAL.async_ctx = ctx
        return ctx
|
|
|
|
|
|
def set_async_ctx(ctx: AsyncContext) -> None:
    """
    Bind ``ctx`` as the AsyncContext of the calling thread.

    Subsequent get_async_ctx() calls on this thread return ``ctx``. The value is
    stored on a threading.local object, so other threads are unaffected.
    """
    # Assigning an attribute does not rebind the module-level name, so no
    # ``global`` declaration is needed.
    ASYNC_CTX_THREAD_LOCAL.async_ctx = ctx
|
|
|
|
|
|
class AsyncThread(threading.Thread, Generic[P, R]):
    """
    A thread that runs one callable and records its outcome as an AsyncResult.

    When the callable returns or raises, ``finished`` is set and all waiters on
    the shared ``condition`` are notified.
    """

    function: Callable[P, R]
    args: Any
    kwargs: Any
    result: AsyncResult[R] | None
    finished: bool
    condition: threading.Condition
    async_opts: AsyncOpts

    def __init__(
        self,
        async_opts: AsyncOpts,
        condition: threading.Condition,
        function: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """
        A threaded wrapper for running a function asynchronously.

        Args:
            async_opts: options for this task; its ``async_ctx`` is installed as
                the thread-local context while ``function`` runs.
            condition: condition variable shared with the owning runtime; it is
                notified once this thread finishes.
            function: the callable to run.
            *args, **kwargs: forwarded to ``function``.
        """
        super().__init__()
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.result = None  # Store the result or exception
        self.finished = False  # Set to True after the thread finishes execution
        self.condition = condition  # Shared condition variable
        self.async_opts = async_opts

    def run(self) -> None:
        """
        Run the function in this thread and store its result or exception.

        Only ``Exception`` subclasses are captured into ``result``. If the
        callable raises another BaseException (e.g. SystemExit), ``result``
        stays None, although ``finished`` is still set.
        """
        try:
            set_async_ctx(self.async_opts.async_ctx)
            self.result = AsyncResult(_result=self.function(*self.args, **self.kwargs))
        except Exception as ex:
            self.result = AsyncResult(_result=ex)
        finally:
            # Flip ``finished`` while holding the condition's lock, the same lock
            # under which waiters (e.g. AsyncRuntime.join_all) read it, so the
            # state change and the notification are atomic for them.
            with self.condition:
                self.finished = True
                self.condition.notify_all()  # Notify waiting threads that this thread is done
|
|
|
|
|
|
@dataclass
class AsyncFuture(Generic[R]):
    """
    Handle to a task started by AsyncRuntime.async_run.

    The task is looked up by id in the runtime's ``tasks`` dict. Collecting a
    finished task's result removes it from that dict, so a result can be
    collected only once.
    """

    _tid: str  # id of the task in ``_runtime.tasks``
    _runtime: "AsyncRuntime"  # runtime that owns the task

    def _lookup_thread(self) -> "AsyncThread[Any, Any]":
        """Return the task's thread, raising ClanError if the id is unknown."""
        tasks = self._runtime.tasks
        if self._tid in tasks:
            return tasks[self._tid]
        msg = f"No task with the name '{self._tid}' exists."
        raise ClanError(msg)

    def wait(self) -> AsyncResult[R]:
        """
        Block until the task finishes, then collect its result.

        Raises:
            ClanError: if no task with this id is registered, or no result
                could be collected.
        """
        self._lookup_thread().join()
        outcome = self.get_result()
        if outcome is not None:
            return outcome
        msg = f"Task '{self._tid}' unexpectedly returned None."
        raise ClanError(msg)

    def get_result(self) -> AsyncResult[R] | None:
        """
        Collect the result of a finished task without blocking.

        Returns None while the task is still running. Once it has finished, the
        task is removed from the runtime's task list and its AsyncResult is
        returned.

        Raises:
            ClanError: if no task with this id is registered, or the finished
                task has no stored result.
        """
        worker = self._lookup_thread()
        if not worker.finished:
            return None

        # Remove the task after retrieving the result
        outcome = worker.result
        del self._runtime.tasks[self._tid]

        if outcome is None:
            msg = f"The result for task '{self._tid}' is unexpectedly None."
            raise ClanError(msg)
        return outcome
|
|
|
|
|
|
@dataclass
class AsyncRuntime:
    """
    Runs callables on background threads and tracks them by task id.

    Usable as a context manager: on exit, every task that has not finished yet
    gets its ``async_ctx.cancel`` flag set. Tasks are not joined on exit.
    """

    # Registered tasks keyed by task id. An entry is removed when its result is
    # collected through AsyncFuture.get_result / AsyncFuture.wait.
    tasks: dict[str, AsyncThread[Any, Any]] = field(default_factory=dict)
    # Condition shared with every AsyncThread; notified whenever one finishes.
    condition: threading.Condition = field(default_factory=threading.Condition)

    def async_run(
        self,
        opts: AsyncOpts | None,
        function: Callable[P, R],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> AsyncFuture[R]:
        """
        Run the given function asynchronously in a thread with a specific name and arguments.
        The function's static typing is preserved.

        Args:
            opts: task options; None means a default AsyncOpts(). A missing
                ``tid`` is filled in with a random hex id.
            function: the callable to run.
            *args, **kwargs: forwarded to ``function``.

        Returns:
            An AsyncFuture for collecting the task's result.

        Raises:
            ClanError: if a task with the same id is still registered.
        """
        if opts is None:
            opts = AsyncOpts()

        # NOTE(review): the generated id is written back into the caller's
        # ``opts``. Reusing one AsyncOpts for another call while the first task
        # is still registered therefore hits the duplicate-id error below.
        if opts.tid is None:
            opts.tid = uuid.uuid4().hex

        if opts.tid in self.tasks:
            msg = f"A task with the name '{opts.tid}' is already running."
            raise ClanError(msg)

        # Create and start the new AsyncThread
        thread = AsyncThread(opts, self.condition, function, *args, **kwargs)
        self.tasks[opts.tid] = thread
        thread.start()
        return AsyncFuture(opts.tid, self)

    def join_all(self) -> None:
        """
        Wait for all registered tasks to finish.
        """
        with self.condition:
            # wait_for re-checks the predicate after every notification, since one
            # finished thread does not mean all of them are done.
            self.condition.wait_for(
                lambda: all(task.finished for task in self.tasks.values())
            )

    def check_all(self) -> None:
        """
        Check whether any finished task failed.

        Every finished task whose options enable ``check`` and whose callable
        raised is logged at ERROR level, with a traceback when DEBUG logging is
        enabled. Tasks still running, or whose results were already collected
        via AsyncFuture (and thus removed from ``tasks``), are not examined.

        Raises:
            ClanError: if at least one checked task failed.
        """
        err_count = 0

        for name, task in self.tasks.items():
            if not (task.finished and task.async_opts.check):
                continue
            assert task.result is not None
            error = task.result.error
            if error is None:
                # The task completed successfully; nothing to report.
                continue
            log.error(
                f"failed with error: {error}",
                extra={"command_prefix": name},
                # Attach the traceback only when debugging, to keep normal output short.
                exc_info=error if log.isEnabledFor(logging.DEBUG) else None,
            )
            err_count += 1

        if err_count > 0:
            msg = f"{err_count} hosts failed with an error. Check the logs above"
            raise ClanError(msg)

    def __enter__(self) -> "AsyncRuntime":
        """
        Enter the runtime context related to this object.
        """
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """
        Exit the runtime context related to this object.

        Sets ``async_ctx.cancel`` to True on every unfinished task to signal
        cancellation. Tasks must poll the flag themselves; they are not joined
        here, and any in-flight exception is not suppressed.
        """
        for name, task in self.tasks.items():
            if not task.finished:
                task.async_opts.async_ctx.cancel = True
                log.debug(f"Canceling task {name}")
|
|
|
|
|
|
# Example usage: one task that succeeds and one that fails.
if __name__ == "__main__":
    runtime = AsyncRuntime()

    def add(a: int, b: int) -> int:
        return a + b

    def concatenate(a: str, b: str) -> str:
        # Deliberately fails after a short delay to demonstrate error reporting.
        time.sleep(1)
        msg = "Hello World"
        raise ClanError(msg)

    with runtime:
        add_future = runtime.async_run(None, add, 1, 2)
        concat_future = runtime.async_run(None, concatenate, "Hello ", "World")

    print(add_future.wait().result)  # Output: 3
    print(concat_future.wait().error)  # Output: Hello World