From b23d2b65e1b51797127419c31632b48d1b39fec3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Tue, 3 Oct 2023 17:48:56 +0200 Subject: [PATCH] task_manager: return task directly instead of uuid --- pkgs/clan-cli/clan_cli/task_manager.py | 17 +++++++++-------- pkgs/clan-cli/clan_cli/vms/create.py | 9 ++++----- pkgs/clan-cli/clan_cli/webui/routers/vms.py | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/task_manager.py b/pkgs/clan-cli/clan_cli/task_manager.py index f73ff61c3..30ac9411b 100644 --- a/pkgs/clan-cli/clan_cli/task_manager.py +++ b/pkgs/clan-cli/clan_cli/task_manager.py @@ -5,7 +5,7 @@ import select import shlex import subprocess import threading -from typing import Any, Iterator +from typing import Any, Iterator, Type, TypeVar from uuid import UUID, uuid4 @@ -156,14 +156,15 @@ def get_task(uuid: UUID) -> BaseTask: return POOL[uuid] -def register_task(task: type, *args: Any) -> UUID: +T = TypeVar("T", bound="BaseTask") + + +def create_task(task_type: Type[T], *args: Any) -> T: global POOL - if not issubclass(task, BaseTask): - raise TypeError("task must be a subclass of BaseTask") uuid = uuid4() - inst_task = task(uuid, *args) - POOL[uuid] = inst_task - inst_task.start() - return uuid + task = task_type(uuid, *args) + POOL[uuid] = task + task.start() + return task diff --git a/pkgs/clan-cli/clan_cli/vms/create.py b/pkgs/clan-cli/clan_cli/vms/create.py index b3eb51d4e..dd4764f23 100644 --- a/pkgs/clan-cli/clan_cli/vms/create.py +++ b/pkgs/clan-cli/clan_cli/vms/create.py @@ -7,7 +7,7 @@ from uuid import UUID from ..dirs import get_clan_flake_toplevel from ..nix import nix_build, nix_shell -from ..task_manager import BaseTask, CmdState, get_task, register_task +from ..task_manager import BaseTask, CmdState, create_task from .inspect import VmConfig @@ -104,8 +104,8 @@ class BuildVmTask(BaseTask): ) -def create_vm(vm: VmConfig) -> UUID: - return register_task(BuildVmTask, vm) +def create_vm(vm: VmConfig) -> BuildVmTask: + return create_task(BuildVmTask, vm) def create_command(args: argparse.Namespace) -> None: @@ -118,8 +118,7 @@ def create_command(args: argparse.Namespace) -> None: memory_size=0, ) - uuid = create_vm(vm) - task = get_task(uuid) + task = create_vm(vm) for line in task.logs_iter(): print(line, end="") diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index a56e9a17d..340b4c738 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -54,5 +54,5 @@ async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse: status_code=status.HTTP_400_BAD_REQUEST, detail=f"Provided attribute '{vm.flake_attr}' does not exist.", ) - uuid = create.create_vm(vm) - return VmCreateResponse(uuid=str(uuid)) + task = create.create_vm(vm) + return VmCreateResponse(uuid=str(task.uuid))