Merge pull request 'fix task manager race conditions' (#404) from lassulus-taskmanager into main
This commit is contained in:
@@ -21,7 +21,8 @@ class Command:
|
|||||||
self._output: queue.SimpleQueue = queue.SimpleQueue()
|
self._output: queue.SimpleQueue = queue.SimpleQueue()
|
||||||
self.returncode: int | None = None
|
self.returncode: int | None = None
|
||||||
self.done: bool = False
|
self.done: bool = False
|
||||||
self.lines: list[str] = []
|
self.stdout: list[str] = []
|
||||||
|
self.stderr: list[str] = []
|
||||||
|
|
||||||
def close_queue(self) -> None:
|
def close_queue(self) -> None:
|
||||||
if self.p is not None:
|
if self.p is not None:
|
||||||
@@ -31,7 +32,6 @@ class Command:
|
|||||||
|
|
||||||
def run(self, cmd: list[str]) -> None:
|
def run(self, cmd: list[str]) -> None:
|
||||||
self.running = True
|
self.running = True
|
||||||
try:
|
|
||||||
self.log.debug(f"Running command: {shlex.join(cmd)}")
|
self.log.debug(f"Running command: {shlex.join(cmd)}")
|
||||||
self.p = subprocess.Popen(
|
self.p = subprocess.Popen(
|
||||||
cmd,
|
cmd,
|
||||||
@@ -50,7 +50,10 @@ class Command:
|
|||||||
try:
|
try:
|
||||||
for line in fd:
|
for line in fd:
|
||||||
self.log.debug("stdout: %s", line)
|
self.log.debug("stdout: %s", line)
|
||||||
self.lines.append(line)
|
if fd == self.p.stderr:
|
||||||
|
self.stderr.append(line)
|
||||||
|
else:
|
||||||
|
self.stdout.append(line)
|
||||||
self._output.put(line)
|
self._output.put(line)
|
||||||
except BlockingIOError:
|
except BlockingIOError:
|
||||||
continue
|
continue
|
||||||
@@ -59,8 +62,6 @@ class Command:
|
|||||||
raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
|
raise ClanError(f"Failed to run command: {shlex.join(cmd)}")
|
||||||
|
|
||||||
self.log.debug("Successfully ran command")
|
self.log.debug("Successfully ran command")
|
||||||
finally:
|
|
||||||
self.close_queue()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
class TaskStatus(str, Enum):
|
||||||
@@ -71,7 +72,7 @@ class TaskStatus(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class BaseTask:
|
class BaseTask:
|
||||||
def __init__(self, uuid: UUID) -> None:
|
def __init__(self, uuid: UUID, num_cmds: int) -> None:
|
||||||
# constructor
|
# constructor
|
||||||
self.uuid: UUID = uuid
|
self.uuid: UUID = uuid
|
||||||
self.log = logging.getLogger(__name__)
|
self.log = logging.getLogger(__name__)
|
||||||
@@ -80,6 +81,10 @@ class BaseTask:
|
|||||||
self.logs_lock = threading.Lock()
|
self.logs_lock = threading.Lock()
|
||||||
self.error: Exception | None = None
|
self.error: Exception | None = None
|
||||||
|
|
||||||
|
for _ in range(num_cmds):
|
||||||
|
cmd = Command(self.log)
|
||||||
|
self.procs.append(cmd)
|
||||||
|
|
||||||
def _run(self) -> None:
|
def _run(self) -> None:
|
||||||
self.status = TaskStatus.RUNNING
|
self.status = TaskStatus.RUNNING
|
||||||
try:
|
try:
|
||||||
@@ -87,13 +92,14 @@ class BaseTask:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# FIXME: fix exception handling here
|
# FIXME: fix exception handling here
|
||||||
traceback.print_exception(*sys.exc_info())
|
traceback.print_exception(*sys.exc_info())
|
||||||
for proc in self.procs:
|
|
||||||
proc.close_queue()
|
|
||||||
self.error = e
|
self.error = e
|
||||||
self.log.exception(e)
|
self.log.exception(e)
|
||||||
self.status = TaskStatus.FAILED
|
self.status = TaskStatus.FAILED
|
||||||
else:
|
else:
|
||||||
self.status = TaskStatus.FINISHED
|
self.status = TaskStatus.FINISHED
|
||||||
|
finally:
|
||||||
|
for proc in self.procs:
|
||||||
|
proc.close_queue()
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -106,19 +112,16 @@ class BaseTask:
|
|||||||
return
|
return
|
||||||
# process has finished
|
# process has finished
|
||||||
if proc.done:
|
if proc.done:
|
||||||
for line in proc.lines:
|
for line in proc.stdout:
|
||||||
|
yield line
|
||||||
|
for line in proc.stderr:
|
||||||
yield line
|
yield line
|
||||||
else:
|
else:
|
||||||
while line := proc._output.get():
|
while line := proc._output.get():
|
||||||
yield line
|
yield line
|
||||||
|
|
||||||
def register_commands(self, num_cmds: int) -> Iterator[Command]:
|
def commands(self) -> Iterator[Command]:
|
||||||
for _ in range(num_cmds):
|
yield from self.procs
|
||||||
cmd = Command(self.log)
|
|
||||||
self.procs.append(cmd)
|
|
||||||
|
|
||||||
for cmd in self.procs:
|
|
||||||
yield cmd
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: We need to test concurrency
|
# TODO: We need to test concurrency
|
||||||
@@ -157,6 +160,6 @@ def create_task(task_type: Type[T], *args: Any) -> T:
|
|||||||
uuid = uuid4()
|
uuid = uuid4()
|
||||||
|
|
||||||
task = task_type(uuid, *args)
|
task = task_type(uuid, *args)
|
||||||
threading.Thread(target=task._run).start()
|
|
||||||
POOL[uuid] = task
|
POOL[uuid] = task
|
||||||
|
threading.Thread(target=task._run).start()
|
||||||
return task
|
return task
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from .inspect import VmConfig, inspect_vm
|
|||||||
|
|
||||||
class BuildVmTask(BaseTask):
|
class BuildVmTask(BaseTask):
|
||||||
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
|
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
|
||||||
super().__init__(uuid)
|
super().__init__(uuid, num_cmds=4)
|
||||||
self.vm = vm
|
self.vm = vm
|
||||||
|
|
||||||
def get_vm_create_info(self, cmds: Iterator[Command]) -> dict:
|
def get_vm_create_info(self, cmds: Iterator[Command]) -> dict:
|
||||||
@@ -30,13 +30,13 @@ class BuildVmTask(BaseTask):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
vm_json = "".join(cmd.lines)
|
vm_json = "".join(cmd.stdout)
|
||||||
self.log.debug(f"VM JSON path: {vm_json}")
|
self.log.debug(f"VM JSON path: {vm_json}")
|
||||||
with open(vm_json.strip()) as f:
|
with open(vm_json.strip()) as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
cmds = self.register_commands(4)
|
cmds = self.commands()
|
||||||
|
|
||||||
machine = self.vm.flake_attr
|
machine = self.vm.flake_attr
|
||||||
self.log.debug(f"Creating VM for {machine}")
|
self.log.debug(f"Creating VM for {machine}")
|
||||||
|
|||||||
@@ -28,7 +28,8 @@ async def inspect_vm(
|
|||||||
async def get_vm_status(uuid: UUID) -> VmStatusResponse:
|
async def get_vm_status(uuid: UUID) -> VmStatusResponse:
|
||||||
task = get_task(uuid)
|
task = get_task(uuid)
|
||||||
log.debug(msg=f"error: {task.error}, task.status: {task.status}")
|
log.debug(msg=f"error: {task.error}, task.status: {task.status}")
|
||||||
return VmStatusResponse(status=task.status, error=str(task.error))
|
error = str(task.error) if task.error is not None else None
|
||||||
|
return VmStatusResponse(status=task.status, error=error)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/vms/{uuid}/logs")
|
@router.get("/api/vms/{uuid}/logs")
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
clan.virtualisation.graphics = false;
|
clan.virtualisation.graphics = false;
|
||||||
|
|
||||||
clan.networking.zerotier.controller.enable = true;
|
clan.networking.zerotier.controller.enable = true;
|
||||||
|
networking.useDHCP = false;
|
||||||
|
|
||||||
systemd.services.shutdown-after-boot = {
|
systemd.services.shutdown-after-boot = {
|
||||||
enable = true;
|
enable = true;
|
||||||
|
|||||||
Reference in New Issue
Block a user