task_manager: return task directly instead of uuid
This commit is contained in:
@@ -5,7 +5,7 @@ import select
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Any, Iterator
|
||||
from typing import Any, Iterator, Type, TypeVar
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
@@ -156,14 +156,15 @@ def get_task(uuid: UUID) -> BaseTask:
|
||||
return POOL[uuid]
|
||||
|
||||
|
||||
def register_task(task: type, *args: Any) -> UUID:
|
||||
T = TypeVar("T", bound="BaseTask")
|
||||
|
||||
|
||||
def create_task(task_type: Type[T], *args: Any) -> T:
|
||||
global POOL
|
||||
if not issubclass(task, BaseTask):
|
||||
raise TypeError("task must be a subclass of BaseTask")
|
||||
|
||||
uuid = uuid4()
|
||||
|
||||
inst_task = task(uuid, *args)
|
||||
POOL[uuid] = inst_task
|
||||
inst_task.start()
|
||||
return uuid
|
||||
task = task_type(uuid, *args)
|
||||
POOL[uuid] = task
|
||||
task.start()
|
||||
return task
|
||||
|
||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
|
||||
from ..dirs import get_clan_flake_toplevel
|
||||
from ..nix import nix_build, nix_shell
|
||||
from ..task_manager import BaseTask, CmdState, get_task, register_task
|
||||
from ..task_manager import BaseTask, CmdState, create_task
|
||||
from .inspect import VmConfig
|
||||
|
||||
|
||||
@@ -104,8 +104,8 @@ class BuildVmTask(BaseTask):
|
||||
)
|
||||
|
||||
|
||||
def create_vm(vm: VmConfig) -> UUID:
|
||||
return register_task(BuildVmTask, vm)
|
||||
def create_vm(vm: VmConfig) -> BuildVmTask:
|
||||
return create_task(BuildVmTask, vm)
|
||||
|
||||
|
||||
def create_command(args: argparse.Namespace) -> None:
|
||||
@@ -118,8 +118,7 @@ def create_command(args: argparse.Namespace) -> None:
|
||||
memory_size=0,
|
||||
)
|
||||
|
||||
uuid = create_vm(vm)
|
||||
task = get_task(uuid)
|
||||
task = create_vm(vm)
|
||||
for line in task.logs_iter():
|
||||
print(line, end="")
|
||||
|
||||
|
||||
@@ -54,5 +54,5 @@ async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
|
||||
)
|
||||
uuid = create.create_vm(vm)
|
||||
return VmCreateResponse(uuid=str(uuid))
|
||||
task = create.create_vm(vm)
|
||||
return VmCreateResponse(uuid=str(task.uuid))
|
||||
|
||||
Reference in New Issue
Block a user