API: init support for narrowing union types

This allows to relax constraints on functions using overloaded interfaces
I.e. for unifying logic this allows passing 'callable | dict'
Conretely useful for prompt values that are asked on demand in the cli, vs upfront in the ui
This commit is contained in:
Johannes Kirschbauer
2025-08-18 18:06:00 +02:00
parent 3fb8b6587d
commit 1213608f30
2 changed files with 34 additions and 10 deletions

View File

@@ -221,7 +221,9 @@ API.register(get_system_file)
try: try:
serialized_hints = { serialized_hints = {
key: type_to_dict( key: type_to_dict(
value, scope=name + " argument" if key != "return" else "return" value,
scope=name + " argument" if key != "return" else "return",
narrow_unsupported_union_types=True,
) )
for key, value in hints.items() for key, value in hints.items()
} }

View File

@@ -104,7 +104,10 @@ def is_total(typed_dict_class: type) -> bool:
def type_to_dict( def type_to_dict(
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None t: Any,
scope: str = "",
type_map: dict[TypeVar, type] | None = None,
narrow_unsupported_union_types: bool = False,
) -> dict: ) -> dict:
if type_map is None: if type_map is None:
type_map = {} type_map = {}
@@ -164,6 +167,8 @@ def type_to_dict(
dict_properties: dict = {} dict_properties: dict = {}
dict_required: list[str] = [] dict_required: list[str] = []
for field_name, field_type in dict_fields.items(): for field_name, field_type in dict_fields.items():
# Unwrap special case for "NotRequired" and "Required"
# A field type that only exist for TypedDicts
if ( if (
not is_type_in_union(field_type, type(None)) not is_type_in_union(field_type, type(None))
and get_origin(field_type) is not NotRequired and get_origin(field_type) is not NotRequired
@@ -181,9 +186,32 @@ def type_to_dict(
"additionalProperties": False, "additionalProperties": False,
} }
if type(t) is UnionType: origin = get_origin(t)
# UnionTypes
if type(t) is UnionType or origin is Union:
supported = []
for arg in get_args(t):
try:
supported.append(
type_to_dict(arg, scope, type_map, narrow_unsupported_union_types)
)
except JSchemaTypeError:
if narrow_unsupported_union_types:
# If we are narrowing unsupported union types, we skip the error
continue
raise
if len(supported) == 0:
msg = f"{scope} - No supported types in Union {t!s}, type_map: {type_map}"
raise JSchemaTypeError(msg)
if len(supported) == 1:
# If there's only one supported type, return it directly
return supported[0]
# If there are multiple supported types, return them as oneOf
return { return {
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__], "oneOf": supported,
} }
if isinstance(t, TypeVar): if isinstance(t, TypeVar):
@@ -221,12 +249,6 @@ def type_to_dict(
schema = type_to_dict(base_type, scope) # Generate schema for the base type schema = type_to_dict(base_type, scope) # Generate schema for the base type
return apply_annotations(schema, metadata) return apply_annotations(schema, metadata)
if origin is Union:
union_types = [type_to_dict(arg, scope, type_map) for arg in t.__args__]
return {
"oneOf": union_types,
}
if origin in {list, set, frozenset, tuple}: if origin in {list, set, frozenset, tuple}:
return { return {
"type": "array", "type": "array",