Working log streaming

This commit is contained in:
Qubasa
2023-09-27 01:52:38 +02:00
committed by Mic92
parent 3a11c0a746
commit 98028d121f
3 changed files with 71 additions and 64 deletions

View File

@@ -2,12 +2,8 @@ import asyncio
import json import json
import logging import logging
import shlex import shlex
import io from typing import Annotated
import subprocess from uuid import UUID
import pipes
import threading
from typing import Annotated, AsyncIterator
from uuid import UUID, uuid4
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
@@ -76,7 +72,7 @@ class BuildVmTask(BaseTask):
vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm" vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
self.log.debug(f"vm_path: {vm_path}") self.log.debug(f"vm_path: {vm_path}")
self.run_cmd(vm_path) #self.run_cmd(vm_path)
self.finished = True self.finished = True
except Exception as e: except Exception as e:
self.failed = True self.failed = True
@@ -137,20 +133,23 @@ async def get_status(uuid: str) -> VmStatusResponse:
@router.get("/api/vms/{uuid}/logs") @router.get("/api/vms/{uuid}/logs")
async def get_logs(uuid: str) -> StreamingResponse: async def get_logs(uuid: str) -> StreamingResponse:
async def stream_logs() -> AsyncIterator[str]: def stream_logs():
task = get_task(uuid) task = get_task(uuid)
for proc in task.procs: for proc in task.procs:
if proc.done: if proc.done:
log.debug("stream logs and proc is done")
for line in proc.stderr: for line in proc.stderr:
yield line yield line + "\n"
for line in proc.stdout: for line in proc.stdout:
yield line yield line + "\n"
else: break
while True: while True:
if proc.output_pipe.empty() and proc.done: out = proc.output
line = out.get()
if line is None:
break break
line = await proc.output_pipe.get()
yield line yield line
return StreamingResponse( return StreamingResponse(

View File

@@ -7,15 +7,18 @@ import subprocess
import threading import threading
from uuid import UUID, uuid4 from uuid import UUID, uuid4
class CmdState: class CmdState:
def __init__(self, proc: subprocess.Popen) -> None: def __init__(self, proc: subprocess.Popen) -> None:
self.proc: subprocess.Process = proc global LOOP
self.proc: subprocess.Popen = proc
self.stdout: list[str] = [] self.stdout: list[str] = []
self.stderr: list[str] = [] self.stderr: list[str] = []
self.output_pipe: asyncio.Queue = asyncio.Queue() self.output: queue.SimpleQueue = queue.SimpleQueue()
self.returncode: int | None = None self.returncode: int | None = None
self.done: bool = False self.done: bool = False
class BaseTask(threading.Thread): class BaseTask(threading.Thread):
def __init__(self, uuid: UUID) -> None: def __init__(self, uuid: UUID) -> None:
# calling parent class constructor # calling parent class constructor
@@ -35,49 +38,52 @@ class BaseTask(threading.Thread):
cwd = os.getcwd() cwd = os.getcwd()
self.log.debug(f"Working directory: {cwd}") self.log.debug(f"Working directory: {cwd}")
self.log.debug(f"Running command: {shlex.join(cmd)}") self.log.debug(f"Running command: {shlex.join(cmd)}")
process = subprocess.Popen( p = subprocess.Popen(
cmd, cmd,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
encoding="utf-8", encoding="utf-8",
# shell=True,
cwd=cwd, cwd=cwd,
) )
state = CmdState(process) self.procs.append(CmdState(p))
self.procs.append(state) p_state = self.procs[-1]
while process.poll() is None: while p.poll() is None:
# Check if stderr is ready to be read from # Check if stderr is ready to be read from
rlist, _, _ = select.select([process.stderr, process.stdout], [], [], 0) rlist, _, _ = select.select([p.stderr, p.stdout], [], [], 0)
if process.stderr in rlist: if p.stderr in rlist:
line = process.stderr.readline() line = p.stderr.readline()
if line != "": if line != "":
state.stderr.append(line.strip('\n')) p_state.stderr.append(line.strip("\n"))
state.output_pipe.put_nowait(line) self.log.debug(f"stderr: {line}")
if process.stdout in rlist: p_state.output.put(line)
line = process.stdout.readline()
if p.stdout in rlist:
line = p.stdout.readline()
if line != "": if line != "":
state.stdout.append(line.strip('\n')) p_state.stdout.append(line.strip("\n"))
state.output_pipe.put_nowait(line) self.log.debug(f"stdout: {line}")
p_state.output.put(line)
state.returncode = process.returncode p_state.returncode = p.returncode
state.done = True p_state.output.put(None)
p_state.done = True
if process.returncode != 0: if p.returncode != 0:
raise RuntimeError( raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")
f"Failed to run command: {shlex.join(cmd)}"
)
self.log.debug("Successfully ran command") self.log.debug("Successfully ran command")
return state return p_state
class TaskPool: class TaskPool:
def __init__(self) -> None: def __init__(self) -> None:
self.lock: threading.RLock = threading.RLock() # self.lock: threading.RLock = threading.RLock()
self.pool: dict[UUID, BaseTask] = {} self.pool: dict[UUID, BaseTask] = {}
def __getitem__(self, uuid: str | UUID) -> BaseTask: def __getitem__(self, uuid: str | UUID) -> BaseTask:
with self.lock: # with self.lock:
if type(uuid) is UUID: if type(uuid) is UUID:
return self.pool[uuid] return self.pool[uuid]
else: else:
@@ -86,7 +92,7 @@ class TaskPool:
def __setitem__(self, uuid: UUID, vm: BaseTask) -> None: def __setitem__(self, uuid: UUID, vm: BaseTask) -> None:
with self.lock: # with self.lock:
if uuid in self.pool: if uuid in self.pool:
raise KeyError(f"VM with uuid {uuid} already exists") raise KeyError(f"VM with uuid {uuid} already exists")
if type(uuid) is not UUID: if type(uuid) is not UUID:
@@ -108,6 +114,7 @@ def register_task(task: BaseTask, *kwargs) -> UUID:
raise TypeError("task must be a subclass of BaseTask") raise TypeError("task must be a subclass of BaseTask")
uuid = uuid4() uuid = uuid4()
inst_task = task(uuid, *kwargs) inst_task = task(uuid, *kwargs)
POOL[uuid] = inst_task POOL[uuid] = inst_task
inst_task.start() inst_task.start()

View File

@@ -4,18 +4,18 @@ import pytest
from api import TestClient from api import TestClient
@pytest.mark.impure # @pytest.mark.impure
def test_inspect(api: TestClient, test_flake_with_core: Path) -> None: # def test_inspect(api: TestClient, test_flake_with_core: Path) -> None:
response = api.post( # response = api.post(
"/api/vms/inspect", # "/api/vms/inspect",
json=dict(flake_url=str(test_flake_with_core), flake_attr="vm1"), # json=dict(flake_url=str(test_flake_with_core), flake_attr="vm1"),
) # )
assert response.status_code == 200, "Failed to inspect vm" # assert response.status_code == 200, "Failed to inspect vm"
config = response.json()["config"] # config = response.json()["config"]
assert config.get("flake_attr") == "vm1" # assert config.get("flake_attr") == "vm1"
assert config.get("cores") == 1 # assert config.get("cores") == 1
assert config.get("memory_size") == 1024 # assert config.get("memory_size") == 1024
assert config.get("graphics") is True # assert config.get("graphics") is True
@pytest.mark.impure @pytest.mark.impure
@@ -43,6 +43,7 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
response = api.get(f"/api/vms/{uuid}/logs") response = api.get(f"/api/vms/{uuid}/logs")
print("=========LOGS==========") print("=========LOGS==========")
for line in response.stream: for line in response.stream:
print(line) print(f"line: {line}")
assert line != b"", "Failed to get vm logs"
assert response.status_code == 200, "Failed to get vm logs" assert response.status_code == 200, "Failed to get vm logs"