diff --git a/pkgs/clan-cli/clan_cli/async_run.py b/pkgs/clan-cli/clan_cli/async_run.py index 1a31668ff..30582aa54 100644 --- a/pkgs/clan-cli/clan_cli/async_run.py +++ b/pkgs/clan-cli/clan_cli/async_run.py @@ -35,6 +35,7 @@ log = logging.getLogger(__name__) # Define generics for return type and call signature R = TypeVar("R") # Return type of the callable P = ParamSpec("P") # Parameters of the callable +Q = TypeVar("Q") # Data type for the async_opts.data field @dataclass @@ -200,6 +201,15 @@ class AsyncFuture(Generic[R]): return result +@dataclass +class AsyncFutureRef(AsyncFuture[R], Generic[R, Q]): + ref: Q | None + + +class AsyncOptsRef(AsyncOpts, Generic[Q]): + ref: Q | None = None + + @dataclass class AsyncRuntime: tasks: dict[str, AsyncThread[Any, Any]] = field(default_factory=dict) @@ -232,6 +242,21 @@ class AsyncRuntime: thread.start() return AsyncFuture(opts.tid, self) + def async_run_ref( + self, + ref: Q, + opts: AsyncOpts | None, + function: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> AsyncFutureRef[R, Q]: + """ + The same as async_run, but with an additional reference to an object. + This is useful to keep track of the origin of the task. + """ + future = self.async_run(opts, function, *args, **kwargs) + return AsyncFutureRef(future._tid, self, ref) # noqa: SLF001 + def join_all(self) -> None: """ Wait for all tasks to finish