task_manager: return task directly instead of uuid
This commit is contained in:
@@ -5,7 +5,7 @@ import select
|
|||||||
import shlex
|
import shlex
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Iterator
|
from typing import Any, Iterator, Type, TypeVar
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
|
||||||
@@ -156,14 +156,15 @@ def get_task(uuid: UUID) -> BaseTask:
|
|||||||
return POOL[uuid]
|
return POOL[uuid]
|
||||||
|
|
||||||
|
|
||||||
def register_task(task: type, *args: Any) -> UUID:
|
T = TypeVar("T", bound="BaseTask")
|
||||||
|
|
||||||
|
|
||||||
|
def create_task(task_type: Type[T], *args: Any) -> T:
|
||||||
global POOL
|
global POOL
|
||||||
if not issubclass(task, BaseTask):
|
|
||||||
raise TypeError("task must be a subclass of BaseTask")
|
|
||||||
|
|
||||||
uuid = uuid4()
|
uuid = uuid4()
|
||||||
|
|
||||||
inst_task = task(uuid, *args)
|
task = task_type(uuid, *args)
|
||||||
POOL[uuid] = inst_task
|
POOL[uuid] = task
|
||||||
inst_task.start()
|
task.start()
|
||||||
return uuid
|
return task
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from ..dirs import get_clan_flake_toplevel
|
from ..dirs import get_clan_flake_toplevel
|
||||||
from ..nix import nix_build, nix_shell
|
from ..nix import nix_build, nix_shell
|
||||||
from ..task_manager import BaseTask, CmdState, get_task, register_task
|
from ..task_manager import BaseTask, CmdState, create_task
|
||||||
from .inspect import VmConfig
|
from .inspect import VmConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -104,8 +104,8 @@ class BuildVmTask(BaseTask):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_vm(vm: VmConfig) -> UUID:
|
def create_vm(vm: VmConfig) -> BuildVmTask:
|
||||||
return register_task(BuildVmTask, vm)
|
return create_task(BuildVmTask, vm)
|
||||||
|
|
||||||
|
|
||||||
def create_command(args: argparse.Namespace) -> None:
|
def create_command(args: argparse.Namespace) -> None:
|
||||||
@@ -118,8 +118,7 @@ def create_command(args: argparse.Namespace) -> None:
|
|||||||
memory_size=0,
|
memory_size=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
uuid = create_vm(vm)
|
task = create_vm(vm)
|
||||||
task = get_task(uuid)
|
|
||||||
for line in task.logs_iter():
|
for line in task.logs_iter():
|
||||||
print(line, end="")
|
print(line, end="")
|
||||||
|
|
||||||
|
|||||||
@@ -54,5 +54,5 @@ async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
|
detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
|
||||||
)
|
)
|
||||||
uuid = create.create_vm(vm)
|
task = create.create_vm(vm)
|
||||||
return VmCreateResponse(uuid=str(uuid))
|
return VmCreateResponse(uuid=str(task.uuid))
|
||||||
|
|||||||
Reference in New Issue
Block a user