Merge pull request 'task_manager: return task directly instead of uuid' (#390) from Mic92-HEAD into main

This commit is contained in:
clan-bot
2023-10-03 15:53:25 +00:00
3 changed files with 15 additions and 15 deletions

View File

@@ -5,7 +5,7 @@ import select
import shlex import shlex
import subprocess import subprocess
import threading import threading
from typing import Any, Iterator from typing import Any, Iterator, Type, TypeVar
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -156,14 +156,15 @@ def get_task(uuid: UUID) -> BaseTask:
return POOL[uuid] return POOL[uuid]
def register_task(task: type, *args: Any) -> UUID: T = TypeVar("T", bound="BaseTask")
def create_task(task_type: Type[T], *args: Any) -> T:
global POOL global POOL
if not issubclass(task, BaseTask):
raise TypeError("task must be a subclass of BaseTask")
uuid = uuid4() uuid = uuid4()
inst_task = task(uuid, *args) task = task_type(uuid, *args)
POOL[uuid] = inst_task POOL[uuid] = task
inst_task.start() task.start()
return uuid return task

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from ..dirs import get_clan_flake_toplevel from ..dirs import get_clan_flake_toplevel
from ..nix import nix_build, nix_shell from ..nix import nix_build, nix_shell
from ..task_manager import BaseTask, CmdState, get_task, register_task from ..task_manager import BaseTask, CmdState, create_task
from .inspect import VmConfig from .inspect import VmConfig
@@ -104,8 +104,8 @@ class BuildVmTask(BaseTask):
) )
def create_vm(vm: VmConfig) -> UUID: def create_vm(vm: VmConfig) -> BuildVmTask:
return register_task(BuildVmTask, vm) return create_task(BuildVmTask, vm)
def create_command(args: argparse.Namespace) -> None: def create_command(args: argparse.Namespace) -> None:
@@ -118,8 +118,7 @@ def create_command(args: argparse.Namespace) -> None:
memory_size=0, memory_size=0,
) )
uuid = create_vm(vm) task = create_vm(vm)
task = get_task(uuid)
for line in task.logs_iter(): for line in task.logs_iter():
print(line, end="") print(line, end="")

View File

@@ -54,5 +54,5 @@ async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Provided attribute '{vm.flake_attr}' does not exist.", detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
) )
uuid = create.create_vm(vm) task = create.create_vm(vm)
return VmCreateResponse(uuid=str(uuid)) return VmCreateResponse(uuid=str(task.uuid))