Merge pull request 'clan-vm-manager: Fix ClanUrl not pickable' (#919) from Qubasa-main into main

This commit is contained in:
clan-bot
2024-03-08 16:51:45 +00:00
8 changed files with 85 additions and 115 deletions

View File

@@ -3,37 +3,37 @@ import dataclasses
import urllib.parse import urllib.parse
import urllib.request import urllib.request
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, member
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from .errors import ClanError from .errors import ClanError
# Define an enum with different members that have different values
class ClanUrl(Enum):
# Use the dataclass decorator to add fields and methods to the members
@member
@dataclass @dataclass
class REMOTE: class FlakeId:
value: str # The url field holds the HTTP URL _value: str | Path
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.value}" # The __str__ method returns a custom string representation return f"{self._value}" # The __str__ method returns a custom string representation
@property
def path(self) -> Path:
assert isinstance(self._value, Path)
return self._value
@property
def url(self) -> str:
assert isinstance(self._value, str)
return self._value
def __repr__(self) -> str: def __repr__(self) -> str:
return f"ClanUrl.REMOTE({self.value})" return f"ClanUrl({self._value})"
@member def is_local(self) -> bool:
@dataclass return isinstance(self._value, Path)
class LOCAL:
value: Path # The path field holds the local path
def __str__(self) -> str: def is_remote(self) -> bool:
return f"{self.value}" # The __str__ method returns a custom string representation return isinstance(self._value, str)
def __repr__(self) -> str:
return f"ClanUrl.LOCAL({self.value})"
# Parameters defined here will be DELETED from the nested uri # Parameters defined here will be DELETED from the nested uri
@@ -45,20 +45,19 @@ class MachineParams:
@dataclass @dataclass
class MachineData: class MachineData:
url: ClanUrl flake_id: FlakeId
name: str = "defaultVM" name: str = "defaultVM"
params: MachineParams = dataclasses.field(default_factory=MachineParams) params: MachineParams = dataclasses.field(default_factory=MachineParams)
def get_id(self) -> str: def get_id(self) -> str:
return f"{self.url}#{self.name}" return f"{self.flake_id}#{self.name}"
# Define the ClanURI class # Define the ClanURI class
class ClanURI: class ClanURI:
_orig_uri: str _orig_uri: str
_nested_uri: str
_components: urllib.parse.ParseResult _components: urllib.parse.ParseResult
url: ClanUrl flake_id: FlakeId
_machines: list[MachineData] _machines: list[MachineData]
# Initialize the class with a clan:// URI # Initialize the class with a clan:// URI
@@ -72,13 +71,13 @@ class ClanURI:
# Check if the URI starts with clan:// # Check if the URI starts with clan://
# If it does, remove the clan:// prefix # If it does, remove the clan:// prefix
if uri.startswith("clan://"): if uri.startswith("clan://"):
self._nested_uri = uri[7:] nested_uri = uri[7:]
else: else:
raise ClanError(f"Invalid uri: expected clan://, got {uri}") raise ClanError(f"Invalid uri: expected clan://, got {uri}")
# Parse the URI into components # Parse the URI into components
# url://netloc/path;parameters?query#fragment # url://netloc/path;parameters?query#fragment
self._components = urllib.parse.urlparse(self._nested_uri) self._components = urllib.parse.urlparse(nested_uri)
# Replace the query string in the components with the new query string # Replace the query string in the components with the new query string
clean_comps = self._components._replace( clean_comps = self._components._replace(
@@ -86,7 +85,7 @@ class ClanURI:
) )
# Parse the URL into a ClanUrl object # Parse the URL into a ClanUrl object
self.url = self._parse_url(clean_comps) self.flake_id = self._parse_url(clean_comps)
# Parse the fragment into a list of machine queries # Parse the fragment into a list of machine queries
# Then parse every machine query into a MachineParameters object # Then parse every machine query into a MachineParameters object
@@ -99,10 +98,10 @@ class ClanURI:
# If there are no machine fragments, add a default machine # If there are no machine fragments, add a default machine
if len(machine_frags) == 0: if len(machine_frags) == 0:
default_machine = MachineData(url=self.url) default_machine = MachineData(flake_id=self.flake_id)
self._machines.append(default_machine) self._machines.append(default_machine)
def _parse_url(self, comps: urllib.parse.ParseResult) -> ClanUrl: def _parse_url(self, comps: urllib.parse.ParseResult) -> FlakeId:
comb = ( comb = (
comps.scheme, comps.scheme,
comps.netloc, comps.netloc,
@@ -113,11 +112,11 @@ class ClanURI:
) )
match comb: match comb:
case ("file", "", path, "", "", _) | ("", "", path, "", "", _): # type: ignore case ("file", "", path, "", "", _) | ("", "", path, "", "", _): # type: ignore
url = ClanUrl.LOCAL.value(Path(path).expanduser().resolve()) # type: ignore flake_id = FlakeId(Path(path).expanduser().resolve())
case _: case _:
url = ClanUrl.REMOTE.value(comps.geturl()) # type: ignore flake_id = FlakeId(comps.geturl())
return url return flake_id
def _parse_machine_query(self, machine_frag: str) -> MachineData: def _parse_machine_query(self, machine_frag: str) -> MachineData:
comp = urllib.parse.urlparse(machine_frag) comp = urllib.parse.urlparse(machine_frag)
@@ -137,7 +136,7 @@ class ClanURI:
# we need to make sure there are no conflicts # we need to make sure there are no conflicts
del query[dfield.name] del query[dfield.name]
params = MachineParams(**machine_params) params = MachineParams(**machine_params)
machine = MachineData(url=self.url, name=machine_name, params=params) machine = MachineData(flake_id=self.flake_id, name=machine_name, params=params)
return machine return machine
@property @property
@@ -148,7 +147,7 @@ class ClanURI:
return self._orig_uri return self._orig_uri
def get_url(self) -> str: def get_url(self) -> str:
return str(self.url) return str(self.flake_id)
@classmethod @classmethod
def from_str( def from_str(

View File

@@ -17,13 +17,6 @@ from ..locked_open import read_history_file, write_history_file
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)
@dataclasses.dataclass @dataclasses.dataclass
class HistoryEntry: class HistoryEntry:
last_used: str last_used: str

View File

@@ -0,0 +1,15 @@
import dataclasses
import json
from typing import Any
class ClanJSONEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
# Check if the object has a to_json method
if hasattr(o, "to_json") and callable(o.to_json):
return o.to_json()
# Check if the object is a dataclass
elif dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
# Otherwise, use the default serialization
return super().default(o)

View File

@@ -1,4 +1,3 @@
import dataclasses
import fcntl import fcntl
import json import json
from collections.abc import Generator from collections.abc import Generator
@@ -6,16 +5,11 @@ from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from clan_cli.jsonrpc import ClanJSONEncoder
from .dirs import user_history_file from .dirs import user_history_file
class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)
@contextmanager @contextmanager
def _locked_open(filename: str | Path, mode: str = "r") -> Generator: def _locked_open(filename: str | Path, mode: str = "r") -> Generator:
""" """
@@ -29,7 +23,7 @@ def _locked_open(filename: str | Path, mode: str = "r") -> Generator:
def write_history_file(data: Any) -> None: def write_history_file(data: Any) -> None:
with _locked_open(user_history_file(), "w+") as f: with _locked_open(user_history_file(), "w+") as f:
f.write(json.dumps(data, cls=EnhancedJSONEncoder, indent=4)) f.write(json.dumps(data, cls=ClanJSONEncoder, indent=4))
def read_history_file() -> list[dict]: def read_history_file() -> list[dict]:

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
from clan_cli.clan_uri import ClanURI, ClanUrl, MachineData from clan_cli.clan_uri import ClanURI, MachineData
from clan_cli.dirs import vm_state_dir from clan_cli.dirs import vm_state_dir
from qemu.qmp import QEMUMonitorProtocol from qemu.qmp import QEMUMonitorProtocol
@@ -66,7 +66,7 @@ class Machine:
if machine is None: if machine is None:
uri = ClanURI.from_str(str(flake), name) uri = ClanURI.from_str(str(flake), name)
machine = uri.machine machine = uri.machine
self.flake: str | Path = machine.url.value self.flake: str | Path = machine.flake_id._value
self.name: str = machine.name self.name: str = machine.name
self.data: MachineData = machine self.data: MachineData = machine
else: else:
@@ -77,12 +77,12 @@ class Machine:
self._flake_path: Path | None = None self._flake_path: Path | None = None
self._deployment_info: None | dict[str, str] = deployment_info self._deployment_info: None | dict[str, str] = deployment_info
state_dir = vm_state_dir(flake_url=str(self.data.url), vm_name=self.data.name) state_dir = vm_state_dir(flake_url=str(self.flake), vm_name=self.data.name)
self.vm: QMPWrapper = QMPWrapper(state_dir) self.vm: QMPWrapper = QMPWrapper(state_dir)
def __str__(self) -> str: def __str__(self) -> str:
return f"Machine(name={self.data.name}, flake={self.data.url})" return f"Machine(name={self.data.name}, flake={self.data.flake_id})"
def __repr__(self) -> str: def __repr__(self) -> str:
return str(self) return str(self)
@@ -139,11 +139,12 @@ class Machine:
if self._flake_path: if self._flake_path:
return self._flake_path return self._flake_path
match self.data.url: if self.data.flake_id.is_local():
case ClanUrl.LOCAL.value(path): self._flake_path = self.data.flake_id.path
self._flake_path = path elif self.data.flake_id.is_remote():
case ClanUrl.REMOTE.value(url): self._flake_path = Path(nix_metadata(self.data.flake_id.url)["path"])
self._flake_path = Path(nix_metadata(url)["path"]) else:
raise ClanError(f"Unsupported flake url: {self.data.flake_id}")
assert self._flake_path is not None assert self._flake_path is not None
return self._flake_path return self._flake_path

View File

@@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from clan_cli.clan_uri import ClanURI, ClanUrl from clan_cli.clan_uri import ClanURI
def test_get_url() -> None: def test_get_url() -> None:
@@ -21,22 +21,13 @@ def test_get_url() -> None:
def test_local_uri() -> None: def test_local_uri() -> None:
# Create a ClanURI object from a local URI # Create a ClanURI object from a local URI
uri = ClanURI("clan://file:///home/user/Downloads") uri = ClanURI("clan://file:///home/user/Downloads")
match uri.url: assert uri.flake_id.path == Path("/home/user/Downloads")
case ClanUrl.LOCAL.value(path):
assert path == Path("/home/user/Downloads") # type: ignore
case _:
assert False
def test_is_remote() -> None: def test_is_remote() -> None:
# Create a ClanURI object from a remote URI # Create a ClanURI object from a remote URI
uri = ClanURI("clan://https://example.com") uri = ClanURI("clan://https://example.com")
assert uri.flake_id.url == "https://example.com"
match uri.url:
case ClanUrl.REMOTE.value(url):
assert url == "https://example.com" # type: ignore
case _:
assert False
def test_direct_local_path() -> None: def test_direct_local_path() -> None:
@@ -56,12 +47,7 @@ def test_remote_with_clanparams() -> None:
uri = ClanURI("clan://https://example.com") uri = ClanURI("clan://https://example.com")
assert uri.machine.name == "defaultVM" assert uri.machine.name == "defaultVM"
assert uri.flake_id.url == "https://example.com"
match uri.url:
case ClanUrl.REMOTE.value(url):
assert url == "https://example.com" # type: ignore
case _:
assert False
def test_remote_with_all_params() -> None: def test_remote_with_all_params() -> None:
@@ -69,11 +55,7 @@ def test_remote_with_all_params() -> None:
assert uri.machine.name == "myVM" assert uri.machine.name == "myVM"
assert uri._machines[1].name == "secondVM" assert uri._machines[1].name == "secondVM"
assert uri._machines[1].params.dummy_opt == "1" assert uri._machines[1].params.dummy_opt == "1"
match uri.url: assert uri.flake_id.url == "https://example.com?password=12345"
case ClanUrl.REMOTE.value(url):
assert url == "https://example.com?password=12345" # type: ignore
case _:
assert False
def test_from_str_remote() -> None: def test_from_str_remote() -> None:
@@ -82,11 +64,7 @@ def test_from_str_remote() -> None:
assert uri.get_orig_uri() == "clan://https://example.com#myVM" assert uri.get_orig_uri() == "clan://https://example.com#myVM"
assert uri.machine.name == "myVM" assert uri.machine.name == "myVM"
assert len(uri._machines) == 1 assert len(uri._machines) == 1
match uri.url: assert uri.flake_id.url == "https://example.com"
case ClanUrl.REMOTE.value(url):
assert url == "https://example.com" # type: ignore
case _:
assert False
def test_from_str_local() -> None: def test_from_str_local() -> None:
@@ -95,11 +73,8 @@ def test_from_str_local() -> None:
assert uri.get_orig_uri() == "clan://~/Projects/democlan#myVM" assert uri.get_orig_uri() == "clan://~/Projects/democlan#myVM"
assert uri.machine.name == "myVM" assert uri.machine.name == "myVM"
assert len(uri._machines) == 1 assert len(uri._machines) == 1
match uri.url: assert uri.flake_id.is_local()
case ClanUrl.LOCAL.value(path): assert str(uri.flake_id).endswith("/Projects/democlan") # type: ignore
assert str(path).endswith("/Projects/democlan") # type: ignore
case _:
assert False
def test_from_str_local_no_machine() -> None: def test_from_str_local_no_machine() -> None:
@@ -108,11 +83,8 @@ def test_from_str_local_no_machine() -> None:
assert uri.get_orig_uri() == "clan://~/Projects/democlan" assert uri.get_orig_uri() == "clan://~/Projects/democlan"
assert uri.machine.name == "defaultVM" assert uri.machine.name == "defaultVM"
assert len(uri._machines) == 1 assert len(uri._machines) == 1
match uri.url: assert uri.flake_id.is_local()
case ClanUrl.LOCAL.value(path): assert str(uri.flake_id).endswith("/Projects/democlan") # type: ignore
assert str(path).endswith("/Projects/democlan") # type: ignore
case _:
assert False
def test_from_str_local_no_machine2() -> None: def test_from_str_local_no_machine2() -> None:
@@ -121,8 +93,5 @@ def test_from_str_local_no_machine2() -> None:
assert uri.get_orig_uri() == "clan://~/Projects/democlan#syncthing-peer1" assert uri.get_orig_uri() == "clan://~/Projects/democlan#syncthing-peer1"
assert uri.machine.name == "syncthing-peer1" assert uri.machine.name == "syncthing-peer1"
assert len(uri._machines) == 1 assert len(uri._machines) == 1
match uri.url: assert uri.flake_id.is_local()
case ClanUrl.LOCAL.value(path): assert str(uri.flake_id).endswith("/Projects/democlan") # type: ignore
assert str(path).endswith("/Projects/democlan") # type: ignore
case _:
assert False

View File

@@ -13,7 +13,7 @@ from typing import IO, ClassVar
import gi import gi
from clan_cli import vms from clan_cli import vms
from clan_cli.clan_uri import ClanURI, ClanUrl from clan_cli.clan_uri import ClanURI
from clan_cli.history.add import HistoryEntry from clan_cli.history.add import HistoryEntry
from clan_cli.machines.machines import Machine from clan_cli.machines.machines import Machine
@@ -115,16 +115,15 @@ class VMObject(GObject.Object):
uri = ClanURI.from_str( uri = ClanURI.from_str(
url=self.data.flake.flake_url, machine_name=self.data.flake.flake_attr url=self.data.flake.flake_url, machine_name=self.data.flake.flake_attr
) )
match uri.url: if uri.flake_id.is_local():
case ClanUrl.LOCAL.value(path):
self.machine = Machine( self.machine = Machine(
name=self.data.flake.flake_attr, name=self.data.flake.flake_attr,
flake=path, # type: ignore flake=uri.flake_id.path,
) )
case ClanUrl.REMOTE.value(url): if uri.flake_id.is_remote():
self.machine = Machine( self.machine = Machine(
name=self.data.flake.flake_attr, name=self.data.flake.flake_attr,
flake=url, # type: ignore flake=uri.flake_id.url,
) )
yield self.machine yield self.machine
self.machine = None self.machine = None

View File

@@ -111,7 +111,7 @@ class ClanStore:
del self.clan_store[vm.data.flake.flake_url][vm.data.flake.flake_attr] del self.clan_store[vm.data.flake.flake_url][vm.data.flake.flake_attr]
def get_vm(self, uri: ClanURI) -> None | VMObject: def get_vm(self, uri: ClanURI) -> None | VMObject:
vm_store = self.clan_store.get(str(uri.url)) vm_store = self.clan_store.get(str(uri.flake_id))
if vm_store is None: if vm_store is None:
return None return None
machine = vm_store.get(uri.machine.name, None) machine = vm_store.get(uri.machine.name, None)