vars: refactor - unify get_generators and _get_closure

This commit is contained in:
DavHau
2025-08-13 14:45:54 +07:00
parent aaeb616f82
commit a535450ec0
5 changed files with 27 additions and 34 deletions

View File

@@ -427,12 +427,25 @@ def _get_previous_value(
return None
def _get_closure(
@API.register
def get_generators(
machine: "Machine",
generator_name: str | None,
full_closure: bool,
generator_name: str | None = None,
include_previous_values: bool = False,
) -> list[Generator]:
"""
Get generators for a machine, with optional closure computation.
Args:
machine: The machine to get generators for.
full_closure: If True, include all dependency generators. If False, only include missing ones.
generator_name: Name of a specific generator to get, or None for all generators.
include_previous_values: If True, populate prompts with their previous values.
Returns:
List of generators based on the specified selection and closure mode.
"""
from . import graph
vars_generators = Generator.get_machine_generators(machine.name, machine.flake)
@@ -510,31 +523,6 @@ def _generate_vars_for_machine(
)
@API.register
def get_generators(
machine_name: str,
base_dir: Path,
include_previous_values: bool = False,
) -> list[Generator]:
"""
Get the list of generators for a machine, optionally with previous values.
If `full_closure` is True, it returns the full closure of generators.
If `include_previous_values` is True, it includes the previous values for prompts.
Args:
machine_name (str): The name of the machine.
base_dir (Path): The base directory of the flake.
Returns:
list[Generator]: A list of generators for the machine.
"""
return Generator.get_machine_generators(
machine_name,
Flake(str(base_dir)),
include_previous_values,
)
@API.register
def run_generators(
machine_name: str,
@@ -585,7 +573,7 @@ def create_machine_vars_interactive(
regenerate: bool,
no_sandbox: bool = False,
) -> None:
generators = _get_closure(machine, generator_name, regenerate)
generators = get_generators(machine, regenerate, generator_name)
if len(generators) == 0:
return
all_prompt_values = {}