Extracted threadpool to task_manager.py
This commit is contained in:
@@ -28,8 +28,11 @@ def setup_app() -> FastAPI:
|
||||
app.include_router(flake.router)
|
||||
app.include_router(health.router)
|
||||
app.include_router(machines.router)
|
||||
app.include_router(root.router)
|
||||
app.include_router(vms.router)
|
||||
|
||||
# Needs to be last in register. Because of wildcard route
|
||||
app.include_router(root.router)
|
||||
|
||||
app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler)
|
||||
|
||||
app.mount("/static", StaticFiles(directory=asset_path()), name="static")
|
||||
@@ -37,6 +40,7 @@ def setup_app() -> FastAPI:
|
||||
for route in app.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
route.operation_id = route.name # in this case, 'read_items'
|
||||
log.debug(f"Registered route: {route}")
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import select
|
||||
import queue
|
||||
import shlex
|
||||
import io
|
||||
import subprocess
|
||||
import pipes
|
||||
import threading
|
||||
from typing import Annotated, AsyncIterator
|
||||
from uuid import UUID, uuid4
|
||||
@@ -23,13 +22,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from ...nix import nix_build, nix_eval
|
||||
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
|
||||
from ..task_manager import BaseTask, get_task, register_task
|
||||
|
||||
# Logging setup
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
|
||||
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
||||
@@ -48,35 +44,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/vms/inspect")
|
||||
async def inspect_vm(
|
||||
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
|
||||
) -> VmInspectResponse:
|
||||
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
cmd[0],
|
||||
*cmd[1:],
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise NixBuildException(
|
||||
f"""
|
||||
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
|
||||
command: {shlex.join(cmd)}
|
||||
exit code: {proc.returncode}
|
||||
command output:
|
||||
{stderr.decode("utf-8")}
|
||||
"""
|
||||
)
|
||||
data = json.loads(stdout)
|
||||
return VmInspectResponse(
|
||||
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
|
||||
)
|
||||
|
||||
|
||||
class NixBuildException(HTTPException):
|
||||
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
|
||||
self.uuid = uuid
|
||||
@@ -93,146 +60,97 @@ class NixBuildException(HTTPException):
|
||||
)
|
||||
|
||||
|
||||
class ProcessState:
|
||||
def __init__(self, proc: subprocess.Popen):
|
||||
self.proc: subprocess.Process = proc
|
||||
self.stdout: list[str] = []
|
||||
self.stderr: list[str] = []
|
||||
self.returncode: int | None = None
|
||||
self.done: bool = False
|
||||
class BuildVmTask(BaseTask):
|
||||
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
|
||||
super().__init__(uuid)
|
||||
self.vm = vm
|
||||
|
||||
class BuildVM(threading.Thread):
|
||||
def __init__(self, vm: VmConfig, uuid: UUID):
|
||||
# calling parent class constructor
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
# constructor
|
||||
self.vm: VmConfig = vm
|
||||
self.uuid: UUID = uuid
|
||||
self.log = logging.getLogger(__name__)
|
||||
self.procs: list[ProcessState] = []
|
||||
self.failed: bool = False
|
||||
self.finished: bool = False
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
try:
|
||||
|
||||
self.log.debug(f"BuildVM with uuid {self.uuid} started")
|
||||
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
|
||||
|
||||
proc = self.run_cmd(cmd)
|
||||
out = proc.stdout
|
||||
self.log.debug(f"out: {out}")
|
||||
self.log.debug(f"stdout: {proc.stdout}")
|
||||
|
||||
vm_path = f"{''.join(out)}/bin/run-nixos-vm"
|
||||
vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
|
||||
self.log.debug(f"vm_path: {vm_path}")
|
||||
self.run_cmd(vm_path)
|
||||
|
||||
self.run_cmd(vm_path)
|
||||
self.finished = True
|
||||
except Exception as e:
|
||||
self.failed = True
|
||||
self.finished = True
|
||||
log.exception(e)
|
||||
|
||||
def run_cmd(self, cmd: list[str]) -> ProcessState:
|
||||
cwd = os.getcwd()
|
||||
log.debug(f"Working directory: {cwd}")
|
||||
log.debug(f"Running command: {shlex.join(cmd)}")
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf-8",
|
||||
cwd=cwd,
|
||||
)
|
||||
state = ProcessState(process)
|
||||
self.procs.append(state)
|
||||
|
||||
while process.poll() is None:
|
||||
# Check if stderr is ready to be read from
|
||||
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
|
||||
if process.stderr in rlist:
|
||||
line = process.stderr.readline()
|
||||
state.stderr.append(line)
|
||||
if process.stdout in rlist:
|
||||
line = process.stdout.readline()
|
||||
state.stdout.append(line)
|
||||
|
||||
state.returncode = process.returncode
|
||||
state.done = True
|
||||
|
||||
if process.returncode != 0:
|
||||
raise NixBuildException(
|
||||
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
|
||||
)
|
||||
|
||||
log.debug("Successfully ran command")
|
||||
return state
|
||||
|
||||
|
||||
class VmTaskPool:
|
||||
def __init__(self) -> None:
|
||||
self.lock: threading.RLock = threading.RLock()
|
||||
self.pool: dict[UUID, BuildVM] = {}
|
||||
|
||||
def __getitem__(self, uuid: str | UUID) -> BuildVM:
|
||||
with self.lock:
|
||||
if type(uuid) is UUID:
|
||||
return self.pool[uuid]
|
||||
else:
|
||||
uuid = UUID(uuid)
|
||||
return self.pool[uuid]
|
||||
|
||||
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
|
||||
with self.lock:
|
||||
if uuid in self.pool:
|
||||
raise KeyError(f"VM with uuid {uuid} already exists")
|
||||
if type(uuid) is not UUID:
|
||||
raise TypeError("uuid must be of type UUID")
|
||||
self.pool[uuid] = vm
|
||||
|
||||
|
||||
POOL: VmTaskPool = VmTaskPool()
|
||||
|
||||
|
||||
def nix_build_exception_handler(
|
||||
request: Request, exc: NixBuildException
|
||||
) -> JSONResponse:
|
||||
log.error("NixBuildException: %s", exc)
|
||||
# del POOL[exc.uuid]
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=jsonable_encoder(dict(detail=exc.detail)),
|
||||
)
|
||||
|
||||
|
||||
##################################
|
||||
# #
|
||||
# ======== VM ROUTES ======== #
|
||||
# #
|
||||
##################################
|
||||
@router.post("/api/vms/inspect")
|
||||
async def inspect_vm(
|
||||
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
|
||||
) -> VmInspectResponse:
|
||||
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
cmd[0],
|
||||
*cmd[1:],
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise NixBuildException(
|
||||
""
|
||||
f"""
|
||||
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
|
||||
command: {shlex.join(cmd)}
|
||||
exit code: {proc.returncode}
|
||||
command output:
|
||||
{stderr.decode("utf-8")}
|
||||
"""
|
||||
)
|
||||
data = json.loads(stdout)
|
||||
return VmInspectResponse(
|
||||
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/vms/{uuid}/status")
|
||||
async def get_status(uuid: str) -> VmStatusResponse:
|
||||
global POOL
|
||||
handle = POOL[uuid]
|
||||
|
||||
if handle.process.poll() is None:
|
||||
return VmStatusResponse(running=True, status=0)
|
||||
else:
|
||||
return VmStatusResponse(running=False, status=handle.process.returncode)
|
||||
|
||||
task = get_task(uuid)
|
||||
return VmStatusResponse(running=not task.finished, status=0)
|
||||
|
||||
|
||||
@router.get("/api/vms/{uuid}/logs")
|
||||
async def get_logs(uuid: str) -> StreamingResponse:
|
||||
async def stream_logs() -> AsyncIterator[str]:
|
||||
global POOL
|
||||
handle = POOL[uuid]
|
||||
for proc in handle.procs.values():
|
||||
while True:
|
||||
if proc.stdout.empty() and proc.stderr.empty() and not proc.done:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
|
||||
break
|
||||
task = get_task(uuid)
|
||||
|
||||
for proc in task.procs:
|
||||
if proc.done:
|
||||
for line in proc.stderr:
|
||||
yield line
|
||||
for line in proc.stdout:
|
||||
yield line
|
||||
for line in proc.stderr:
|
||||
else:
|
||||
while True:
|
||||
if proc.output_pipe.empty() and proc.done:
|
||||
break
|
||||
line = await proc.output_pipe.get()
|
||||
yield line
|
||||
|
||||
return StreamingResponse(
|
||||
@@ -240,14 +158,9 @@ async def get_logs(uuid: str) -> StreamingResponse:
|
||||
media_type="text/plain",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/vms/create")
|
||||
async def create_vm(
|
||||
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
|
||||
) -> VmCreateResponse:
|
||||
global POOL
|
||||
uuid = uuid4()
|
||||
handle = BuildVM(vm, uuid)
|
||||
handle.start()
|
||||
POOL[uuid] = handle
|
||||
uuid = register_task(BuildVmTask, vm)
|
||||
return VmCreateResponse(uuid=str(uuid))
|
||||
|
||||
114
pkgs/clan-cli/clan_cli/webui/task_manager.py
Normal file
114
pkgs/clan-cli/clan_cli/webui/task_manager.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import select
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
class CmdState:
    """Captured state of a single command started by ``BaseTask.run_cmd``.

    Attributes:
        proc: The ``subprocess.Popen`` handle of the command.
        stdout: Lines read from the command's stdout, without newlines.
        stderr: Lines read from the command's stderr, without newlines.
        output_pipe: Queue receiving every raw line (stdout and stderr,
            newline included) in the order it was read.
        returncode: Exit status once the command has finished, else ``None``.
        done: ``True`` once the command has exited and its output was read.
    """

    def __init__(self, proc: subprocess.Popen) -> None:
        # Imported locally: ``asyncio`` is not among this module's visible
        # top-level imports, so the queue construction below would
        # otherwise fail with NameError.
        import asyncio

        self.proc: subprocess.Popen = proc
        self.stdout: list[str] = []
        self.stderr: list[str] = []
        # NOTE(review): asyncio.Queue is documented as not thread-safe, yet
        # run_cmd() fills it from a worker thread. Kept as-is because
        # consumers ``await output_pipe.get()``; confirm whether a
        # thread-safe hand-off (e.g. loop.call_soon_threadsafe) is needed.
        self.output_pipe: asyncio.Queue = asyncio.Queue()
        self.returncode: int | None = None
        self.done: bool = False


class BaseTask(threading.Thread):
    """Thread that runs external commands and records their output.

    ``run`` is the thread body and is meant to be overridden; ``run_cmd``
    is the helper that executes one command and tracks it in ``procs``.
    """

    def __init__(self, uuid: UUID) -> None:
        super().__init__()

        self.uuid: UUID = uuid
        self.log = logging.getLogger(__name__)
        self.procs: list[CmdState] = []
        self.failed: bool = False
        self.finished: bool = False

    def run(self) -> None:
        """Default thread body: only marks the task as finished."""
        self.finished = True

    def run_cmd(self, cmd: list[str]) -> CmdState:
        """Run ``cmd`` to completion, capturing stdout and stderr line by line.

        Args:
            cmd: Program and arguments, executed without a shell in the
                current working directory.

        Returns:
            The ``CmdState`` for the command (also appended to ``self.procs``).

        Raises:
            RuntimeError: If the command exits with a non-zero status.
        """
        cwd = os.getcwd()
        self.log.debug(f"Working directory: {cwd}")
        self.log.debug(f"Running command: {shlex.join(cmd)}")
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            cwd=cwd,
        )
        state = CmdState(process)
        self.procs.append(state)

        def record(line: str, sink: list[str]) -> None:
            # An empty string means EOF on that pipe, not an empty line.
            if line != "":
                sink.append(line.strip("\n"))
                state.output_pipe.put_nowait(line)

        try:
            while process.poll() is None:
                # A short timeout instead of 0 avoids spinning at 100% CPU
                # while the command is silent.
                rlist, _, _ = select.select(
                    [process.stderr, process.stdout], [], [], 0.1
                )
                if process.stderr in rlist:
                    record(process.stderr.readline(), state.stderr)
                if process.stdout in rlist:
                    record(process.stdout.readline(), state.stdout)

            # The process may exit with output still sitting in the pipes or
            # in the text-layer buffers; read both to EOF so nothing is lost.
            for line in process.stderr:
                record(line, state.stderr)
            for line in process.stdout:
                record(line, state.stdout)
        finally:
            # Release the pipe file descriptors even if reading failed.
            process.stderr.close()
            process.stdout.close()

        state.returncode = process.returncode
        state.done = True

        if process.returncode != 0:
            raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")

        self.log.debug("Successfully ran command")
        return state
|
||||
|
||||
|
||||
class TaskPool:
|
||||
def __init__(self) -> None:
|
||||
self.lock: threading.RLock = threading.RLock()
|
||||
self.pool: dict[UUID, BaseTask] = {}
|
||||
|
||||
def __getitem__(self, uuid: str | UUID) -> BaseTask:
|
||||
with self.lock:
|
||||
if type(uuid) is UUID:
|
||||
return self.pool[uuid]
|
||||
else:
|
||||
uuid = UUID(uuid)
|
||||
return self.pool[uuid]
|
||||
|
||||
|
||||
def __setitem__(self, uuid: UUID, vm: BaseTask) -> None:
|
||||
with self.lock:
|
||||
if uuid in self.pool:
|
||||
raise KeyError(f"VM with uuid {uuid} already exists")
|
||||
if type(uuid) is not UUID:
|
||||
raise TypeError("uuid must be of type UUID")
|
||||
self.pool[uuid] = vm
|
||||
|
||||
|
||||
# Module-level task registry: read by get_task() and filled by register_task().
POOL: TaskPool = TaskPool()
|
||||
|
||||
|
||||
def get_task(uuid: str | UUID) -> BaseTask:
    """Return the task registered under ``uuid`` in the module-level ``POOL``.

    Args:
        uuid: The task id, either as a ``UUID`` or as its string form
            (``TaskPool.__getitem__`` converts strings).

    Raises:
        KeyError: If no task is registered under ``uuid``.
        ValueError: If ``uuid`` is a string that is not a valid UUID.
    """
    # Read-only access to the module global; no ``global`` statement needed.
    return POOL[uuid]
|
||||
|
||||
|
||||
def register_task(task: type[BaseTask], *args, **kwargs) -> UUID:
    """Instantiate a task class, register it in ``POOL`` and start its thread.

    Args:
        task: A ``BaseTask`` subclass (the class itself, not an instance).
        *args: Extra positional arguments forwarded to the task constructor
            after the generated uuid.
        **kwargs: Extra keyword arguments forwarded to the task constructor.

    Returns:
        The freshly generated UUID under which the task was registered.

    Raises:
        TypeError: If ``task`` is not a subclass of ``BaseTask``.
    """
    # The isinstance guard keeps issubclass() from raising its own, less
    # helpful TypeError when a non-class (e.g. an instance) is passed.
    if not (isinstance(task, type) and issubclass(task, BaseTask)):
        raise TypeError("task must be a subclass of BaseTask")

    uuid = uuid4()
    inst_task = task(uuid, *args, **kwargs)
    # Register before starting so the task can be looked up as soon as it runs.
    POOL[uuid] = inst_task
    inst_task.start()
    return uuid
|
||||
@@ -31,14 +31,18 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
|
||||
graphics=True,
|
||||
),
|
||||
)
|
||||
assert response.status_code == 200, "Failed to inspect vm"
|
||||
assert response.status_code == 200, "Failed to create vm"
|
||||
|
||||
uuid = response.json()["uuid"]
|
||||
assert len(uuid) == 36
|
||||
assert uuid.count("-") == 4
|
||||
|
||||
response = api.get(f"/api/vms/{uuid}/status")
|
||||
assert response.status_code == 200, "Failed to get vm status"
|
||||
|
||||
response = api.get(f"/api/vms/{uuid}/logs")
|
||||
print("=========LOGS==========")
|
||||
for line in response.stream:
|
||||
print(line)
|
||||
|
||||
assert response.status_code == 200, "Failed to get vm status"
|
||||
assert response.status_code == 200, "Failed to get vm logs"
|
||||
Reference in New Issue
Block a user