Extracted threadpool to task_manager.py

This commit is contained in:
Qubasa
2023-09-26 19:36:01 +02:00
committed by Mic92
parent b535f745e0
commit c96b339d94
4 changed files with 185 additions and 150 deletions

View File

@@ -28,8 +28,11 @@ def setup_app() -> FastAPI:
app.include_router(flake.router) app.include_router(flake.router)
app.include_router(health.router) app.include_router(health.router)
app.include_router(machines.router) app.include_router(machines.router)
app.include_router(root.router)
app.include_router(vms.router) app.include_router(vms.router)
# Must be registered last because it contains a wildcard route
app.include_router(root.router)
app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler) app.add_exception_handler(vms.NixBuildException, vms.nix_build_exception_handler)
app.mount("/static", StaticFiles(directory=asset_path()), name="static") app.mount("/static", StaticFiles(directory=asset_path()), name="static")
@@ -37,6 +40,7 @@ def setup_app() -> FastAPI:
for route in app.routes: for route in app.routes:
if isinstance(route, APIRoute): if isinstance(route, APIRoute):
route.operation_id = route.name # in this case, 'read_items' route.operation_id = route.name # in this case, 'read_items'
log.debug(f"Registered route: {route}")
return app return app

View File

@@ -1,11 +1,10 @@
import asyncio import asyncio
import json import json
import logging import logging
import os
import select
import queue
import shlex import shlex
import io
import subprocess import subprocess
import pipes
import threading import threading
from typing import Annotated, AsyncIterator from typing import Annotated, AsyncIterator
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -23,13 +22,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
from ...nix import nix_build, nix_eval from ...nix import nix_build, nix_eval
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task
# Logging setup
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
app = FastAPI()
def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]: def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
@@ -48,35 +44,6 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
) )
@router.post("/api/vms/inspect")
async def inspect_vm(
flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
proc = await asyncio.create_subprocess_exec(
cmd[0],
*cmd[1:],
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
raise NixBuildException(
f"""
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
)
data = json.loads(stdout)
return VmInspectResponse(
config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
)
class NixBuildException(HTTPException): class NixBuildException(HTTPException):
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]): def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
self.uuid = uuid self.uuid = uuid
@@ -93,146 +60,97 @@ class NixBuildException(HTTPException):
) )
class ProcessState: class BuildVmTask(BaseTask):
def __init__(self, proc: subprocess.Popen): def __init__(self, uuid: UUID, vm: VmConfig) -> None:
self.proc: subprocess.Process = proc super().__init__(uuid)
self.stdout: list[str] = [] self.vm = vm
self.stderr: list[str] = []
self.returncode: int | None = None
self.done: bool = False
class BuildVM(threading.Thread): def run(self) -> None:
def __init__(self, vm: VmConfig, uuid: UUID):
# calling parent class constructor
threading.Thread.__init__(self)
# constructor
self.vm: VmConfig = vm
self.uuid: UUID = uuid
self.log = logging.getLogger(__name__)
self.procs: list[ProcessState] = []
self.failed: bool = False
self.finished: bool = False
def run(self):
try: try:
self.log.debug(f"BuildVM with uuid {self.uuid} started") self.log.debug(f"BuildVM with uuid {self.uuid} started")
cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url) cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)
proc = self.run_cmd(cmd) proc = self.run_cmd(cmd)
out = proc.stdout self.log.debug(f"stdout: {proc.stdout}")
self.log.debug(f"out: {out}")
vm_path = f"{''.join(out)}/bin/run-nixos-vm" vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
self.log.debug(f"vm_path: {vm_path}") self.log.debug(f"vm_path: {vm_path}")
self.run_cmd(vm_path)
self.run_cmd(vm_path)
self.finished = True self.finished = True
except Exception as e: except Exception as e:
self.failed = True self.failed = True
self.finished = True self.finished = True
log.exception(e) log.exception(e)
def run_cmd(self, cmd: list[str]) -> ProcessState:
cwd = os.getcwd()
log.debug(f"Working directory: {cwd}")
log.debug(f"Running command: {shlex.join(cmd)}")
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
cwd=cwd,
)
state = ProcessState(process)
self.procs.append(state)
while process.poll() is None:
# Check if stderr is ready to be read from
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0)
if process.stderr in rlist:
line = process.stderr.readline()
state.stderr.append(line)
if process.stdout in rlist:
line = process.stdout.readline()
state.stdout.append(line)
state.returncode = process.returncode
state.done = True
if process.returncode != 0:
raise NixBuildException(
self.uuid, f"Failed to run command: {shlex.join(cmd)}"
)
log.debug("Successfully ran command")
return state
class VmTaskPool:
def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BuildVM] = {}
def __getitem__(self, uuid: str | UUID) -> BuildVM:
with self.lock:
if type(uuid) is UUID:
return self.pool[uuid]
else:
uuid = UUID(uuid)
return self.pool[uuid]
def __setitem__(self, uuid: UUID, vm: BuildVM) -> None:
with self.lock:
if uuid in self.pool:
raise KeyError(f"VM with uuid {uuid} already exists")
if type(uuid) is not UUID:
raise TypeError("uuid must be of type UUID")
self.pool[uuid] = vm
POOL: VmTaskPool = VmTaskPool()
def nix_build_exception_handler( def nix_build_exception_handler(
request: Request, exc: NixBuildException request: Request, exc: NixBuildException
) -> JSONResponse: ) -> JSONResponse:
log.error("NixBuildException: %s", exc) log.error("NixBuildException: %s", exc)
# del POOL[exc.uuid]
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content=jsonable_encoder(dict(detail=exc.detail)), content=jsonable_encoder(dict(detail=exc.detail)),
) )
##################################
# #
# ======== VM ROUTES ======== #
# #
##################################
@router.post("/api/vms/inspect")
async def inspect_vm(
    flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
    """Evaluate a machine's VM configuration from a flake.

    Runs the command built by ``nix_inspect_vm_cmd`` and parses its stdout as
    JSON.

    Args:
        flake_url: Flake to evaluate.
        flake_attr: Machine attribute inside that flake.

    Returns:
        The VM configuration, combining the request parameters with the keys of
        the parsed JSON object.

    Raises:
        NixBuildException: If the command exits with a non-zero status.
    """
    cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
    proc = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()

    if proc.returncode != 0:
        # NixBuildException takes (uuid, msg). Inspection is not registered as
        # a task, so a fresh id is generated purely to satisfy the signature.
        # Passing the id and message as separate arguments matters: adjacent
        # string literals would be concatenated into one argument.
        raise NixBuildException(
            uuid4(),
            f"""
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
""",
        )
    data = json.loads(stdout)
    return VmInspectResponse(
        config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
    )
@router.get("/api/vms/{uuid}/status") @router.get("/api/vms/{uuid}/status")
async def get_status(uuid: str) -> VmStatusResponse: async def get_status(uuid: str) -> VmStatusResponse:
global POOL task = get_task(uuid)
handle = POOL[uuid] return VmStatusResponse(running=not task.finished, status=0)
if handle.process.poll() is None:
return VmStatusResponse(running=True, status=0)
else:
return VmStatusResponse(running=False, status=handle.process.returncode)
@router.get("/api/vms/{uuid}/logs") @router.get("/api/vms/{uuid}/logs")
async def get_logs(uuid: str) -> StreamingResponse: async def get_logs(uuid: str) -> StreamingResponse:
async def stream_logs() -> AsyncIterator[str]: async def stream_logs() -> AsyncIterator[str]:
global POOL task = get_task(uuid)
handle = POOL[uuid]
for proc in handle.procs.values(): for proc in task.procs:
while True: if proc.done:
if proc.stdout.empty() and proc.stderr.empty() and not proc.done: for line in proc.stderr:
await asyncio.sleep(0.1) yield line
continue
if proc.stdout.empty() and proc.stderr.empty() and proc.done:
break
for line in proc.stdout: for line in proc.stdout:
yield line yield line
for line in proc.stderr: else:
while True:
if proc.output_pipe.empty() and proc.done:
break
line = await proc.output_pipe.get()
yield line yield line
return StreamingResponse( return StreamingResponse(
@@ -240,14 +158,9 @@ async def get_logs(uuid: str) -> StreamingResponse:
media_type="text/plain", media_type="text/plain",
) )
@router.post("/api/vms/create") @router.post("/api/vms/create")
async def create_vm( async def create_vm(
vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
) -> VmCreateResponse: ) -> VmCreateResponse:
global POOL uuid = register_task(BuildVmTask, vm)
uuid = uuid4()
handle = BuildVM(vm, uuid)
handle.start()
POOL[uuid] = handle
return VmCreateResponse(uuid=str(uuid)) return VmCreateResponse(uuid=str(uuid))

View File

@@ -0,0 +1,114 @@
import asyncio
import logging
import os
import queue
import select
import shlex
import subprocess
import threading
from uuid import UUID, uuid4
class CmdState:
    """Mutable record of one spawned command and the output read from it.

    Instances are created and filled in by ``BaseTask.run_cmd``; other code
    reads the fields to report progress.
    """

    def __init__(self, proc: subprocess.Popen) -> None:
        """Start an empty record for ``proc``.

        Args:
            proc: Handle of the already spawned child process.
        """
        # Handle of the child process this record belongs to.
        self.proc: subprocess.Popen = proc
        # Lines read from each stream so far, in arrival order.
        self.stdout: list[str] = []
        self.stderr: list[str] = []
        # Every line from both streams, queued for consumers that stream the
        # output while the command is still running.
        self.output_pipe: asyncio.Queue = asyncio.Queue()
        # Exit status; stays None until the command has finished.
        self.returncode: int | None = None
        # Set to True once the command has exited.
        self.done: bool = False
class BaseTask(threading.Thread):
    """Thread that runs external commands and records their output.

    Subclasses override :meth:`run` and call :meth:`run_cmd` for every command
    they need. Each command's ``CmdState`` is appended to ``procs``.
    """

    # Longest time, in seconds, that ``run_cmd`` waits for output before it
    # re-checks whether the child process has exited.
    POLL_INTERVAL: float = 0.1

    def __init__(self, uuid: UUID) -> None:
        """Create a task that has not been started yet.

        Args:
            uuid: Identifier under which the task is known.
        """
        super().__init__()
        self.uuid: UUID = uuid
        self.log = logging.getLogger(__name__)
        # One entry per command started through run_cmd, in start order.
        self.procs: list[CmdState] = []
        self.failed: bool = False
        self.finished: bool = False

    def run(self) -> None:
        """Default thread body: do no work and mark the task as finished."""
        self.finished = True

    def run_cmd(self, cmd: list[str]) -> CmdState:
        """Run ``cmd`` to completion while collecting its output line by line.

        Lines are appended without their trailing newline to the state's
        ``stdout``/``stderr`` lists and pushed unmodified onto its
        ``output_pipe``.

        Args:
            cmd: Program and arguments, executed in the current working
                directory.

        Returns:
            The state record of the finished command.

        Raises:
            RuntimeError: If the command exits with a non-zero status.
        """
        cwd = os.getcwd()
        self.log.debug(f"Working directory: {cwd}")
        self.log.debug(f"Running command: {shlex.join(cmd)}")
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            cwd=cwd,
        )
        state = CmdState(process)
        self.procs.append(state)

        try:
            # The context manager closes both pipes and waits for the child,
            # even when reading raises.
            with process:
                # Streams that have not reached EOF yet, mapped to the list
                # that collects their lines.
                sinks = {process.stderr: state.stderr, process.stdout: state.stdout}
                while sinks:
                    # Wait with a timeout instead of polling in a tight loop.
                    ready, _, _ = select.select(
                        list(sinks), [], [], self.POLL_INTERVAL
                    )
                    for stream in ready:
                        line = stream.readline()
                        if line == "":
                            # EOF: this stream is fully drained.
                            del sinks[stream]
                            continue
                        sinks[stream].append(line.rstrip("\n"))
                        # NOTE(review): asyncio.Queue is documented as not
                        # thread-safe and this runs in a worker thread; confirm
                        # that consumers awaiting get() are woken reliably.
                        state.output_pipe.put_nowait(line)
                    # Reading normally continues until EOF, so output written
                    # right before exit is kept. Stop early only when the child
                    # is gone and nothing is pending, e.g. because another
                    # process still holds the pipe open.
                    if not ready and process.poll() is not None:
                        break
        finally:
            # Leaving the ``with`` block has waited for the child, so the exit
            # status is available. Marking the state done here keeps readers
            # from waiting forever if reading failed.
            state.returncode = process.returncode
            state.done = True

        if process.returncode != 0:
            raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")

        self.log.debug("Successfully ran command")
        return state
class TaskPool:
    """Lock-protected mapping from task UUID to task object."""

    def __init__(self) -> None:
        """Create an empty pool."""
        self.lock: threading.RLock = threading.RLock()
        self.pool: dict[UUID, BaseTask] = {}

    def __getitem__(self, uuid: str | UUID) -> BaseTask:
        """Return the task stored under ``uuid``.

        Args:
            uuid: Task id, either as a ``UUID`` or in a textual form accepted
                by the ``UUID`` constructor.

        Raises:
            KeyError: If no task is stored under that id.
            ValueError: If a string is given that is not a valid UUID.
        """
        # isinstance also accepts UUID subclasses, which must not be re-parsed.
        key = uuid if isinstance(uuid, UUID) else UUID(uuid)
        with self.lock:
            return self.pool[key]

    def __setitem__(self, uuid: UUID, vm: BaseTask) -> None:
        """Store ``vm`` under ``uuid``; existing entries are never replaced.

        Raises:
            TypeError: If ``uuid`` is not a ``UUID``.
            KeyError: If an entry with that id already exists.
        """
        if not isinstance(uuid, UUID):
            raise TypeError("uuid must be of type UUID")
        with self.lock:
            if uuid in self.pool:
                raise KeyError(f"VM with uuid {uuid} already exists")
            self.pool[uuid] = vm
# Module-wide task registry used by get_task() and register_task().
POOL: TaskPool = TaskPool()
def get_task(uuid: str | UUID) -> BaseTask:
    """Look up a registered task in the module-wide pool.

    Args:
        uuid: Task id as a ``UUID`` or as its string form.

    Returns:
        The task registered under ``uuid``.

    Raises:
        KeyError: If no task is registered under that id.
        ValueError: If a string is given that is not a valid UUID.
    """
    return POOL[uuid]
def register_task(task: type[BaseTask], *args: object, **kwargs: object) -> UUID:
    """Instantiate a task class, register the instance and start its thread.

    Args:
        task: The ``BaseTask`` subclass to instantiate (the class itself, not
            an instance).
        *args: Extra positional arguments passed to the constructor after the
            generated uuid.
        **kwargs: Extra keyword arguments passed to the constructor.

    Returns:
        The freshly generated id under which the task was registered.

    Raises:
        TypeError: If ``task`` is not a subclass of ``BaseTask``.
    """
    # issubclass() itself raises for non-class arguments, so check that first
    # to report one consistent error message.
    if not (isinstance(task, type) and issubclass(task, BaseTask)):
        raise TypeError("task must be a subclass of BaseTask")

    uuid = uuid4()
    inst_task = task(uuid, *args, **kwargs)
    # Register before starting so the task can be looked up as soon as it runs.
    POOL[uuid] = inst_task
    inst_task.start()
    return uuid

View File

@@ -31,14 +31,18 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
graphics=True, graphics=True,
), ),
) )
assert response.status_code == 200, "Failed to inspect vm" assert response.status_code == 200, "Failed to create vm"
uuid = response.json()["uuid"] uuid = response.json()["uuid"]
assert len(uuid) == 36 assert len(uuid) == 36
assert uuid.count("-") == 4 assert uuid.count("-") == 4
response = api.get(f"/api/vms/{uuid}/status") response = api.get(f"/api/vms/{uuid}/status")
assert response.status_code == 200, "Failed to get vm status"
response = api.get(f"/api/vms/{uuid}/logs")
print("=========LOGS==========")
for line in response.stream: for line in response.stream:
print(line) print(line)
assert response.status_code == 200, "Failed to get vm status" assert response.status_code == 200, "Failed to get vm logs"