From c96b339d94297a8c61c42ae410ce9b22cf877af3 Mon Sep 17 00:00:00 2001 From: Qubasa Date: Tue, 26 Sep 2023 19:36:01 +0200 Subject: [PATCH] Extracted threadpool to task_manager.py --- pkgs/clan-cli/clan_cli/webui/app.py | 6 +- pkgs/clan-cli/clan_cli/webui/routers/vms.py | 207 ++++++------------- pkgs/clan-cli/clan_cli/webui/task_manager.py | 114 ++++++++++ pkgs/clan-cli/tests/test_vms_api.py | 8 +- 4 files changed, 185 insertions(+), 150 deletions(-) create mode 100644 pkgs/clan-cli/clan_cli/webui/task_manager.py diff --git a/pkgs/clan-cli/clan_cli/webui/app.py b/pkgs/clan-cli/clan_cli/webui/app.py index bd586789d..b392c2118 100644 --- a/pkgs/clan-cli/clan_cli/webui/app.py +++ b/pkgs/clan-cli/clan_cli/webui/app.py @@ -28,8 +28,11 @@ def setup_app() -> FastAPI: app.include_router(flake.router) app.include_router(health.router) app.include_router(machines.router) - app.include_router(root.router) app.include_router(vms.router) + + # Needs to be last in register. Because of wildcard route + app.include_router(root.router) + app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler) app.mount("/static", StaticFiles(directory=asset_path()), name="static") @@ -37,6 +40,7 @@ def setup_app() -> FastAPI: for route in app.routes: if isinstance(route, APIRoute): route.operation_id = route.name # in this case, 'read_items' + log.debug(f"Registered route: {route}") return app diff --git a/pkgs/clan-cli/clan_cli/webui/routers/vms.py b/pkgs/clan-cli/clan_cli/webui/routers/vms.py index 47566fadd..ba6881afa 100644 --- a/pkgs/clan-cli/clan_cli/webui/routers/vms.py +++ b/pkgs/clan-cli/clan_cli/webui/routers/vms.py @@ -1,11 +1,10 @@ import asyncio import json import logging -import os -import select -import queue import shlex +import io import subprocess +import pipes import threading from typing import Annotated, AsyncIterator from uuid import UUID, uuid4 @@ -23,13 +22,10 @@ from fastapi.responses import JSONResponse, StreamingResponse from ...nix import nix_build, nix_eval from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse +from ..task_manager import BaseTask, get_task, register_task -# Logging setup log = logging.getLogger(__name__) - router = APIRouter() -app = FastAPI() - def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]: @@ -48,35 +44,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]: ) -@router.post("/api/vms/inspect") -async def inspect_vm( - flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()] -) -> VmInspectResponse: - cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url) - proc = await asyncio.create_subprocess_exec( - cmd[0], - *cmd[1:], - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await proc.communicate() - - if proc.returncode != 0: - raise NixBuildException( - f""" -Failed to evaluate vm from '{flake_url}#{flake_attr}'. -command: {shlex.join(cmd)} -exit code: {proc.returncode} -command output: -{stderr.decode("utf-8")} -""" - ) - data = json.loads(stdout) - return VmInspectResponse( - config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data) - ) - - class NixBuildException(HTTPException): def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]): self.uuid = uuid @@ -93,146 +60,97 @@ class NixBuildException(HTTPException): ) -class ProcessState: - def __init__(self, proc: subprocess.Popen): - self.proc: subprocess.Process = proc - self.stdout: list[str] = [] - self.stderr: list[str] = [] - self.returncode: int | None = None - self.done: bool = False +class BuildVmTask(BaseTask): + def __init__(self, uuid: UUID, vm: VmConfig) -> None: + super().__init__(uuid) + self.vm = vm -class BuildVM(threading.Thread): - def __init__(self, vm: VmConfig, uuid: UUID): - # calling parent class constructor - threading.Thread.__init__(self) - - # constructor - self.vm: VmConfig = vm - self.uuid: UUID = uuid - self.log = logging.getLogger(__name__) - self.procs: list[ProcessState] = [] - self.failed: bool = False - self.finished: bool = False - - def run(self): + def run(self) -> None: try: - self.log.debug(f"BuildVM with uuid {self.uuid} started") cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url) proc = self.run_cmd(cmd) - out = proc.stdout - self.log.debug(f"out: {out}") + self.log.debug(f"stdout: {proc.stdout}") - vm_path = f"{''.join(out)}/bin/run-nixos-vm" + vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm" self.log.debug(f"vm_path: {vm_path}") - self.run_cmd(vm_path) + self.run_cmd(vm_path) self.finished = True except Exception as e: self.failed = True self.finished = True log.exception(e) - def run_cmd(self, cmd: list[str]) -> ProcessState: - cwd = os.getcwd() - log.debug(f"Working directory: {cwd}") - log.debug(f"Running command: {shlex.join(cmd)}") - process = subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - encoding="utf-8", - cwd=cwd, - ) - state = ProcessState(process) - self.procs.append(state) - - while process.poll() is None: - # Check if stderr is ready to be read from - rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0) - if process.stderr in rlist: - line = process.stderr.readline() - state.stderr.append(line) - if process.stdout in rlist: - line = process.stdout.readline() - state.stdout.append(line) - - state.returncode = process.returncode - state.done = True - - if process.returncode != 0: - raise NixBuildException( - self.uuid, f"Failed to run command: {shlex.join(cmd)}" - ) - - log.debug("Successfully ran command") - return state - - -class VmTaskPool: - def __init__(self) -> None: - self.lock: threading.RLock = threading.RLock() - self.pool: dict[UUID, BuildVM] = {} - - def __getitem__(self, uuid: str | UUID) -> BuildVM: - with self.lock: - if type(uuid) is UUID: - return self.pool[uuid] - else: - uuid = UUID(uuid) - return self.pool[uuid] - - def __setitem__(self, uuid: UUID, vm: BuildVM) -> None: - with self.lock: - if uuid in self.pool: - raise KeyError(f"VM with uuid {uuid} already exists") - if type(uuid) is not UUID: - raise TypeError("uuid must be of type UUID") - self.pool[uuid] = vm - - -POOL: VmTaskPool = VmTaskPool() - def nix_build_exception_handler( request: Request, exc: NixBuildException ) -> JSONResponse: log.error("NixBuildException: %s", exc) - # del POOL[exc.uuid] return JSONResponse( status_code=exc.status_code, content=jsonable_encoder(dict(detail=exc.detail)), ) +################################## +# # +# ======== VM ROUTES ======== # +# # +################################## +@router.post("/api/vms/inspect") +async def inspect_vm( + flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()] +) -> VmInspectResponse: + cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url) + proc = await asyncio.create_subprocess_exec( + cmd[0], + *cmd[1:], + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + + if proc.returncode != 0: + raise NixBuildException( + "" + f""" +Failed to evaluate vm from '{flake_url}#{flake_attr}'. +command: {shlex.join(cmd)} +exit code: {proc.returncode} +command output: +{stderr.decode("utf-8")} +""" + ) + data = json.loads(stdout) + return VmInspectResponse( + config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data) + ) + + @router.get("/api/vms/{uuid}/status") async def get_status(uuid: str) -> VmStatusResponse: - global POOL - handle = POOL[uuid] - - if handle.process.poll() is None: - return VmStatusResponse(running=True, status=0) - else: - return VmStatusResponse(running=False, status=handle.process.returncode) - + task = get_task(uuid) + return VmStatusResponse(running=not task.finished, status=0) @router.get("/api/vms/{uuid}/logs") async def get_logs(uuid: str) -> StreamingResponse: async def stream_logs() -> AsyncIterator[str]: - global POOL - handle = POOL[uuid] - for proc in handle.procs.values(): - while True: - if proc.stdout.empty() and proc.stderr.empty() and not proc.done: - await asyncio.sleep(0.1) - continue - if proc.stdout.empty() and proc.stderr.empty() and proc.done: - break + task = get_task(uuid) + + for proc in task.procs: + if proc.done: + for line in proc.stderr: + yield line for line in proc.stdout: yield line - for line in proc.stderr: + else: + while True: + if proc.output_pipe.empty() and proc.done: + break + line = await proc.output_pipe.get() yield line return StreamingResponse( @@ -240,14 +158,9 @@ async def get_logs(uuid: str) -> StreamingResponse: media_type="text/plain", ) - @router.post("/api/vms/create") async def create_vm( vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks ) -> VmCreateResponse: - global POOL - uuid = uuid4() - handle = BuildVM(vm, uuid) - handle.start() - POOL[uuid] = handle + uuid = register_task(BuildVmTask, vm) return VmCreateResponse(uuid=str(uuid)) diff --git a/pkgs/clan-cli/clan_cli/webui/task_manager.py b/pkgs/clan-cli/clan_cli/webui/task_manager.py new file mode 100644 index 000000000..3ee619e01 --- /dev/null +++ b/pkgs/clan-cli/clan_cli/webui/task_manager.py @@ -0,0 +1,114 @@ +import logging +import os +import queue +import select +import shlex +import subprocess +import threading +from uuid import UUID, uuid4 + +class CmdState: + def __init__(self, proc: subprocess.Popen) -> None: + self.proc: subprocess.Process = proc + self.stdout: list[str] = [] + self.stderr: list[str] = [] + self.output_pipe: asyncio.Queue = asyncio.Queue() + self.returncode: int | None = None + self.done: bool = False + +class BaseTask(threading.Thread): + def __init__(self, uuid: UUID) -> None: + # calling parent class constructor + threading.Thread.__init__(self) + + # constructor + self.uuid: UUID = uuid + self.log = logging.getLogger(__name__) + self.procs: list[CmdState] = [] + self.failed: bool = False + self.finished: bool = False + + def run(self) -> None: + self.finished = True + + def run_cmd(self, cmd: list[str]) -> CmdState: + cwd = os.getcwd() + self.log.debug(f"Working directory: {cwd}") + self.log.debug(f"Running command: {shlex.join(cmd)}") + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + cwd=cwd, + ) + state = CmdState(process) + self.procs.append(state) + + while process.poll() is None: + # Check if stderr is ready to be read from + rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0) + if process.stderr in rlist: + line = process.stderr.readline() + if line != "": + state.stderr.append(line.strip('\n')) + state.output_pipe.put_nowait(line) + if process.stdout in rlist: + line = process.stdout.readline() + if line != "": + state.stdout.append(line.strip('\n')) + state.output_pipe.put_nowait(line) + + state.returncode = process.returncode + state.done = True + + if process.returncode != 0: + raise RuntimeError( + f"Failed to run command: {shlex.join(cmd)}" + ) + + self.log.debug("Successfully ran command") + return state + + +class TaskPool: + def __init__(self) -> None: + self.lock: threading.RLock = threading.RLock() + self.pool: dict[UUID, BaseTask] = {} + + def __getitem__(self, uuid: str | UUID) -> BaseTask: + with self.lock: + if type(uuid) is UUID: + return self.pool[uuid] + else: + uuid = UUID(uuid) + return self.pool[uuid] + + + def __setitem__(self, uuid: UUID, vm: BaseTask) -> None: + with self.lock: + if uuid in self.pool: + raise KeyError(f"VM with uuid {uuid} already exists") + if type(uuid) is not UUID: + raise TypeError("uuid must be of type UUID") + self.pool[uuid] = vm + + +POOL: TaskPool = TaskPool() + + +def get_task(uuid: UUID) -> BaseTask: + global POOL + return POOL[uuid] + + +def register_task(task: BaseTask, *kwargs) -> UUID: + global POOL + if not issubclass(task, BaseTask): + raise TypeError("task must be a subclass of BaseTask") + + uuid = uuid4() + inst_task = task(uuid, *kwargs) + POOL[uuid] = inst_task + inst_task.start() + return uuid diff --git a/pkgs/clan-cli/tests/test_vms_api.py b/pkgs/clan-cli/tests/test_vms_api.py index fdc1a09ae..5aa6d917a 100644 --- a/pkgs/clan-cli/tests/test_vms_api.py +++ b/pkgs/clan-cli/tests/test_vms_api.py @@ -31,14 +31,18 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None: graphics=True, ), ) - assert response.status_code == 200, "Failed to inspect vm" + assert response.status_code == 200, "Failed to create vm" uuid = response.json()["uuid"] assert len(uuid) == 36 assert uuid.count("-") == 4 response = api.get(f"/api/vms/{uuid}/status") + assert response.status_code == 200, "Failed to get vm status" + + response = api.get(f"/api/vms/{uuid}/logs") + print("=========LOGS==========") for line in response.stream: print(line) - assert response.status_code == 200, "Failed to get vm status" \ No newline at end of file + assert response.status_code == 200, "Failed to get vm logs" \ No newline at end of file