Extracted threadpool to task_manager.py
This commit is contained in:
@@ -28,8 +28,11 @@ def setup_app() -> FastAPI:
|
|||||||
app.include_router(flake.router)
|
app.include_router(flake.router)
|
||||||
app.include_router(health.router)
|
app.include_router(health.router)
|
||||||
app.include_router(machines.router)
|
app.include_router(machines.router)
|
||||||
app.include_router(root.router)
|
|
||||||
app.include_router(vms.router)
|
app.include_router(vms.router)
|
||||||
|
|
||||||
|
# Must be registered last because of its wildcard route.
|
||||||
|
app.include_router(root.router)
|
||||||
|
|
||||||
app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler)
|
app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler)
|
||||||
|
|
||||||
app.mount("/static", StaticFiles(directory=asset_path()), name="static")
|
app.mount("/static", StaticFiles(directory=asset_path()), name="static")
|
||||||
@@ -37,6 +40,7 @@ def setup_app() -> FastAPI:
|
|||||||
for route in app.routes:
|
for route in app.routes:
|
||||||
if isinstance(route, APIRoute):
|
if isinstance(route, APIRoute):
|
||||||
route.operation_id = route.name # in this case, 'read_items'
|
route.operation_id = route.name # in this case, 'read_items'
|
||||||
|
log.debug(f"Registered route: {route}")
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import select
|
|
||||||
import queue
|
|
||||||
import shlex
|
import shlex
|
||||||
|
import io
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import pipes
|
||||||
import threading
|
import threading
|
||||||
from typing import Annotated, AsyncIterator
|
from typing import Annotated, AsyncIterator
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
@@ -23,13 +22,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||||||
|
|
||||||
from ...nix import nix_build, nix_eval
|
from ...nix import nix_build, nix_eval
|
||||||
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
|
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
|
||||||
|
from ..task_manager import BaseTask, get_task, register_task
|
||||||
|
|
||||||
# Logging setup
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
||||||
@@ -48,35 +44,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/vms/inspect")
|
|
||||||
async def inspect_vm(
|
|
||||||
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
|
|
||||||
) -> VmInspectResponse:
|
|
||||||
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
cmd[0],
|
|
||||||
*cmd[1:],
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
)
|
|
||||||
stdout, stderr = await proc.communicate()
|
|
||||||
|
|
||||||
if proc.returncode != 0:
|
|
||||||
raise NixBuildException(
|
|
||||||
f"""
|
|
||||||
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
|
|
||||||
command: {shlex.join(cmd)}
|
|
||||||
exit code: {proc.returncode}
|
|
||||||
command output:
|
|
||||||
{stderr.decode("utf-8")}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
data = json.loads(stdout)
|
|
||||||
return VmInspectResponse(
|
|
||||||
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NixBuildException(HTTPException):
|
class NixBuildException(HTTPException):
|
||||||
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
|
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
|
||||||
self.uuid = uuid
|
self.uuid = uuid
|
||||||
@@ -93,146 +60,97 @@ class NixBuildException(HTTPException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProcessState:
|
class BuildVmTask(BaseTask):
|
||||||
def __init__(self, proc: subprocess.Popen):
|
def __init__(self, uuid: UUID, vm: VmConfig) -> None:
|
||||||
self.proc: subprocess.Process = proc
|
super().__init__(uuid)
|
||||||
self.stdout: list[str] = []
|
self.vm = vm
|
||||||
self.stderr: list[str] = []
|
|
||||||
self.returncode: int | None = None
|
|
||||||
self.done: bool = False
|
|
||||||
|
|
||||||
class BuildVM(threading.Thread):
|
def run(self) -> None:
|
||||||
def __init__(self, vm: VmConfig, uuid: UUID):
|
|
||||||
# calling parent class constructor
|
|
||||||
threading.Thread.__init__(self)
|
|
||||||
|
|
||||||
# constructor
|
|
||||||
self.vm: VmConfig = vm
|
|
||||||
self.uuid: UUID = uuid
|
|
||||||
self.log = logging.getLogger(__name__)
|
|
||||||
self.procs: list[ProcessState] = []
|
|
||||||
self.failed: bool = False
|
|
||||||
self.finished: bool = False
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
self.log.debug(f"BuildVM with uuid {self.uuid} started")
|
self.log.debug(f"BuildVM with uuid {self.uuid} started")
|
||||||
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
|
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
|
||||||
|
|
||||||
proc = self.run_cmd(cmd)
|
proc = self.run_cmd(cmd)
|
||||||
out = proc.stdout
|
self.log.debug(f"stdout: {proc.stdout}")
|
||||||
self.log.debug(f"out: {out}")
|
|
||||||
|
|
||||||
vm_path = f"{''.join(out)}/bin/run-nixos-vm"
|
vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
|
||||||
self.log.debug(f"vm_path: {vm_path}")
|
self.log.debug(f"vm_path: {vm_path}")
|
||||||
self.run_cmd(vm_path)
|
|
||||||
|
|
||||||
|
self.run_cmd(vm_path)
|
||||||
self.finished = True
|
self.finished = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.failed = True
|
self.failed = True
|
||||||
self.finished = True
|
self.finished = True
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
||||||
def run_cmd(self, cmd: list[str]) -> ProcessState:
|
|
||||||
cwd = os.getcwd()
|
|
||||||
log.debug(f"Working directory: {cwd}")
|
|
||||||
log.debug(f"Running command: {shlex.join(cmd)}")
|
|
||||||
process = subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
encoding="utf-8",
|
|
||||||
cwd=cwd,
|
|
||||||
)
|
|
||||||
state = ProcessState(process)
|
|
||||||
self.procs.append(state)
|
|
||||||
|
|
||||||
while process.poll() is None:
|
|
||||||
# Check if stderr is ready to be read from
|
|
||||||
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
|
|
||||||
if process.stderr in rlist:
|
|
||||||
line = process.stderr.readline()
|
|
||||||
state.stderr.append(line)
|
|
||||||
if process.stdout in rlist:
|
|
||||||
line = process.stdout.readline()
|
|
||||||
state.stdout.append(line)
|
|
||||||
|
|
||||||
state.returncode = process.returncode
|
|
||||||
state.done = True
|
|
||||||
|
|
||||||
if process.returncode != 0:
|
|
||||||
raise NixBuildException(
|
|
||||||
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
log.debug("Successfully ran command")
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
class VmTaskPool:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.lock: threading.RLock = threading.RLock()
|
|
||||||
self.pool: dict[UUID, BuildVM] = {}
|
|
||||||
|
|
||||||
def __getitem__(self, uuid: str | UUID) -> BuildVM:
|
|
||||||
with self.lock:
|
|
||||||
if type(uuid) is UUID:
|
|
||||||
return self.pool[uuid]
|
|
||||||
else:
|
|
||||||
uuid = UUID(uuid)
|
|
||||||
return self.pool[uuid]
|
|
||||||
|
|
||||||
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
|
|
||||||
with self.lock:
|
|
||||||
if uuid in self.pool:
|
|
||||||
raise KeyError(f"VM with uuid {uuid} already exists")
|
|
||||||
if type(uuid) is not UUID:
|
|
||||||
raise TypeError("uuid must be of type UUID")
|
|
||||||
self.pool[uuid] = vm
|
|
||||||
|
|
||||||
|
|
||||||
POOL: VmTaskPool = VmTaskPool()
|
|
||||||
|
|
||||||
|
|
||||||
def nix_build_exception_handler(
|
def nix_build_exception_handler(
|
||||||
request: Request, exc: NixBuildException
|
request: Request, exc: NixBuildException
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
log.error("NixBuildException: %s", exc)
|
log.error("NixBuildException: %s", exc)
|
||||||
# del POOL[exc.uuid]
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=exc.status_code,
|
status_code=exc.status_code,
|
||||||
content=jsonable_encoder(dict(detail=exc.detail)),
|
content=jsonable_encoder(dict(detail=exc.detail)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
##################################
|
||||||
|
# #
|
||||||
|
# ======== VM ROUTES ======== #
|
||||||
|
# #
|
||||||
|
##################################
|
||||||
|
@router.post("/api/vms/inspect")
async def inspect_vm(
    flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
    """Inspect the VM configuration of a flake attribute.

    Runs the command built by ``nix_inspect_vm_cmd`` as a subprocess, parses
    its stdout as JSON and uses the result as the extra fields of ``VmConfig``.

    Args:
        flake_url: URL of the flake, taken from the request body.
        flake_attr: Attribute (machine) inside the flake, taken from the
            request body.

    Returns:
        The inspected configuration wrapped in a ``VmInspectResponse``.

    Raises:
        NixBuildException: If the command exits with a non-zero status.
    """
    cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
    proc = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()

    if proc.returncode != 0:
        # NixBuildException takes (uuid, msg). No task exists for an inspect
        # request, so a fresh uuid identifies this failure; passing only the
        # message would raise TypeError instead of the intended exception.
        # errors="replace" keeps a stray non-UTF-8 byte in the tool output
        # from masking the real failure with a UnicodeDecodeError.
        raise NixBuildException(
            uuid4(),
            f"""
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8", errors="replace")}
""",
        )

    data = json.loads(stdout)
    return VmInspectResponse(
        config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
    )
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/vms/{uuid}/status")
|
@router.get("/api/vms/{uuid}/status")
|
||||||
async def get_status(uuid: str) -> VmStatusResponse:
|
async def get_status(uuid: str) -> VmStatusResponse:
|
||||||
global POOL
|
task = get_task(uuid)
|
||||||
handle = POOL[uuid]
|
return VmStatusResponse(running=not task.finished, status=0)
|
||||||
|
|
||||||
if handle.process.poll() is None:
|
|
||||||
return VmStatusResponse(running=True, status=0)
|
|
||||||
else:
|
|
||||||
return VmStatusResponse(running=False, status=handle.process.returncode)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/vms/{uuid}/logs")
|
@router.get("/api/vms/{uuid}/logs")
|
||||||
async def get_logs(uuid: str) -> StreamingResponse:
|
async def get_logs(uuid: str) -> StreamingResponse:
|
||||||
async def stream_logs() -> AsyncIterator[str]:
|
async def stream_logs() -> AsyncIterator[str]:
|
||||||
global POOL
|
task = get_task(uuid)
|
||||||
handle = POOL[uuid]
|
|
||||||
for proc in handle.procs.values():
|
for proc in task.procs:
|
||||||
while True:
|
if proc.done:
|
||||||
if proc.stdout.empty() and proc.stderr.empty() and not proc.done:
|
for line in proc.stderr:
|
||||||
await asyncio.sleep(0.1)
|
yield line
|
||||||
continue
|
|
||||||
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
|
|
||||||
break
|
|
||||||
for line in proc.stdout:
|
for line in proc.stdout:
|
||||||
yield line
|
yield line
|
||||||
for line in proc.stderr:
|
else:
|
||||||
|
while True:
|
||||||
|
if proc.output_pipe.empty() and proc.done:
|
||||||
|
break
|
||||||
|
line = await proc.output_pipe.get()
|
||||||
yield line
|
yield line
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
@@ -240,14 +158,9 @@ async def get_logs(uuid: str) -> StreamingResponse:
|
|||||||
media_type="text/plain",
|
media_type="text/plain",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/vms/create")
|
@router.post("/api/vms/create")
|
||||||
async def create_vm(
|
async def create_vm(
|
||||||
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
|
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
|
||||||
) -> VmCreateResponse:
|
) -> VmCreateResponse:
|
||||||
global POOL
|
uuid = register_task(BuildVmTask, vm)
|
||||||
uuid = uuid4()
|
|
||||||
handle = BuildVM(vm, uuid)
|
|
||||||
handle.start()
|
|
||||||
POOL[uuid] = handle
|
|
||||||
return VmCreateResponse(uuid=str(uuid))
|
return VmCreateResponse(uuid=str(uuid))
|
||||||
|
|||||||
114
pkgs/clan-cli/clan_cli/webui/task_manager.py
Normal file
114
pkgs/clan-cli/clan_cli/webui/task_manager.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import select
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
class CmdState:
    """Bookkeeping for one command started by a task.

    Holds the process handle, the output captured so far, and the exit
    status once the command has finished.
    """

    def __init__(self, proc: subprocess.Popen) -> None:
        # Imported locally: the module's import list does not provide asyncio,
        # so referencing asyncio.Queue would otherwise raise NameError.
        import asyncio

        self.proc: subprocess.Popen = proc
        # Captured lines with the trailing newline removed.
        self.stdout: list[str] = []
        self.stderr: list[str] = []
        # Every captured line (stdout and stderr, newline kept) in arrival
        # order, for consumers that stream the output.
        # NOTE(review): asyncio.Queue is documented as not thread-safe, yet it
        # is filled from the task thread via put_nowait — confirm how the
        # consumer is expected to synchronize.
        self.output_pipe: asyncio.Queue = asyncio.Queue()
        # Exit status of the command; None until it has finished.
        self.returncode: int | None = None
        # Set once the command has exited and all output was recorded.
        self.done: bool = False


class BaseTask(threading.Thread):
    """Thread that runs external commands and records their output.

    Subclasses override ``run`` and call ``run_cmd`` for every command they
    need; each command's ``CmdState`` is appended to ``procs``.
    """

    def __init__(self, uuid: UUID) -> None:
        super().__init__()

        self.uuid: UUID = uuid
        self.log = logging.getLogger(__name__)
        self.procs: list[CmdState] = []
        self.failed: bool = False
        self.finished: bool = False

    def run(self) -> None:
        """Default thread body: does nothing and marks the task finished."""
        self.finished = True

    def run_cmd(self, cmd: list[str]) -> CmdState:
        """Run ``cmd`` to completion, capturing stdout and stderr line by line.

        Args:
            cmd: Program and arguments, executed in the current working
                directory.

        Returns:
            The ``CmdState`` holding the captured output and return code.

        Raises:
            RuntimeError: If the command exits with a non-zero status. The
                state is still recorded in ``procs`` with ``done`` set.
        """
        cwd = os.getcwd()
        self.log.debug(f"Working directory: {cwd}")
        self.log.debug(f"Running command: {shlex.join(cmd)}")
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            cwd=cwd,
        )
        state = CmdState(process)
        self.procs.append(state)

        # stderr is listed first so it is serviced before stdout when both
        # are readable.
        sinks = ((process.stderr, state.stderr), (process.stdout, state.stdout))
        pipes = [pipe for pipe, _ in sinks]

        while process.poll() is None:
            # A non-zero timeout lets the thread sleep until output arrives
            # instead of spinning, while still re-checking poll() regularly.
            ready, _, _ = select.select(pipes, [], [], 0.1)
            for pipe, sink in sinks:
                if pipe in ready:
                    self._record_line(state, sink, pipe.readline())

        # The child may exit with output still buffered in the pipes (always
        # the case for very short commands); read it all before reporting
        # completion, then release the file descriptors.
        for pipe, sink in sinks:
            for line in pipe:
                self._record_line(state, sink, line)
            pipe.close()

        state.returncode = process.returncode
        state.done = True

        if process.returncode != 0:
            raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")

        self.log.debug("Successfully ran command")
        return state

    @staticmethod
    def _record_line(state: CmdState, sink: list[str], line: str) -> None:
        """Store one output line in ``sink`` and in the state's output queue."""
        # readline() returns "" at end of file; that is not an output line.
        if line != "":
            sink.append(line.strip("\n"))
            state.output_pipe.put_nowait(line)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskPool:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.lock: threading.RLock = threading.RLock()
|
||||||
|
self.pool: dict[UUID, BaseTask] = {}
|
||||||
|
|
||||||
|
def __getitem__(self, uuid: str | UUID) -> BaseTask:
|
||||||
|
with self.lock:
|
||||||
|
if type(uuid) is UUID:
|
||||||
|
return self.pool[uuid]
|
||||||
|
else:
|
||||||
|
uuid = UUID(uuid)
|
||||||
|
return self.pool[uuid]
|
||||||
|
|
||||||
|
|
||||||
|
def __setitem__(self, uuid: UUID, vm: BaseTask) -> None:
|
||||||
|
with self.lock:
|
||||||
|
if uuid in self.pool:
|
||||||
|
raise KeyError(f"VM with uuid {uuid} already exists")
|
||||||
|
if type(uuid) is not UUID:
|
||||||
|
raise TypeError("uuid must be of type UUID")
|
||||||
|
self.pool[uuid] = vm
|
||||||
|
|
||||||
|
|
||||||
|
POOL: TaskPool = TaskPool()
|
||||||
|
|
||||||
|
|
||||||
|
def get_task(uuid: str | UUID) -> BaseTask:
    """Look up a registered task in the module-level pool.

    Args:
        uuid: Key of the task, as a ``UUID`` or its string form (the pool's
            lookup accepts both).

    Returns:
        The task registered under ``uuid``.

    Raises:
        KeyError: If the pool has no task under ``uuid``.
    """
    # POOL is only read here, so no ``global`` declaration is needed.
    return POOL[uuid]
|
||||||
|
|
||||||
|
|
||||||
|
def register_task(task: type[BaseTask], *args: object) -> UUID:
    """Instantiate a task class, register it in the pool and start its thread.

    Args:
        task: The ``BaseTask`` subclass to instantiate (the class itself, not
            an instance).
        *args: Extra positional arguments forwarded to the task constructor
            after the generated uuid.

    Returns:
        The freshly generated uuid under which the task is registered.

    Raises:
        TypeError: If ``task`` is not a subclass of ``BaseTask``.
    """
    # Check for a class first: issubclass() itself raises an unrelated
    # TypeError message when handed a non-class object.
    if not (isinstance(task, type) and issubclass(task, BaseTask)):
        raise TypeError("task must be a subclass of BaseTask")

    uuid = uuid4()
    inst_task = task(uuid, *args)
    # Register before starting so the task can be looked up as soon as it runs.
    POOL[uuid] = inst_task
    inst_task.start()
    return uuid
|
||||||
@@ -31,14 +31,18 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
|
|||||||
graphics=True,
|
graphics=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
assert response.status_code == 200, "Failed to inspect vm"
|
assert response.status_code == 200, "Failed to create vm"
|
||||||
|
|
||||||
uuid = response.json()["uuid"]
|
uuid = response.json()["uuid"]
|
||||||
assert len(uuid) == 36
|
assert len(uuid) == 36
|
||||||
assert uuid.count("-") == 4
|
assert uuid.count("-") == 4
|
||||||
|
|
||||||
response = api.get(f"/api/vms/{uuid}/status")
|
response = api.get(f"/api/vms/{uuid}/status")
|
||||||
|
assert response.status_code == 200, "Failed to get vm status"
|
||||||
|
|
||||||
|
response = api.get(f"/api/vms/{uuid}/logs")
|
||||||
|
print("=========LOGS==========")
|
||||||
for line in response.stream:
|
for line in response.stream:
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
assert response.status_code == 200, "Failed to get vm status"
|
assert response.status_code == 200, "Failed to get vm logs"
|
||||||
Reference in New Issue
Block a user