API: init support for narrowing union types
This allows to relax constraints on functions using overloaded interfaces I.e. for unifying logic this allows passing 'callable | dict' Conretely useful for prompt values that are asked on demand in the cli, vs upfront in the ui
This commit is contained in:
@@ -221,7 +221,9 @@ API.register(get_system_file)
|
|||||||
try:
|
try:
|
||||||
serialized_hints = {
|
serialized_hints = {
|
||||||
key: type_to_dict(
|
key: type_to_dict(
|
||||||
value, scope=name + " argument" if key != "return" else "return"
|
value,
|
||||||
|
scope=name + " argument" if key != "return" else "return",
|
||||||
|
narrow_unsupported_union_types=True,
|
||||||
)
|
)
|
||||||
for key, value in hints.items()
|
for key, value in hints.items()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -104,7 +104,10 @@ def is_total(typed_dict_class: type) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def type_to_dict(
|
def type_to_dict(
|
||||||
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
|
t: Any,
|
||||||
|
scope: str = "",
|
||||||
|
type_map: dict[TypeVar, type] | None = None,
|
||||||
|
narrow_unsupported_union_types: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if type_map is None:
|
if type_map is None:
|
||||||
type_map = {}
|
type_map = {}
|
||||||
@@ -164,6 +167,8 @@ def type_to_dict(
|
|||||||
dict_properties: dict = {}
|
dict_properties: dict = {}
|
||||||
dict_required: list[str] = []
|
dict_required: list[str] = []
|
||||||
for field_name, field_type in dict_fields.items():
|
for field_name, field_type in dict_fields.items():
|
||||||
|
# Unwrap special case for "NotRequired" and "Required"
|
||||||
|
# A field type that only exist for TypedDicts
|
||||||
if (
|
if (
|
||||||
not is_type_in_union(field_type, type(None))
|
not is_type_in_union(field_type, type(None))
|
||||||
and get_origin(field_type) is not NotRequired
|
and get_origin(field_type) is not NotRequired
|
||||||
@@ -181,9 +186,32 @@ def type_to_dict(
|
|||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if type(t) is UnionType:
|
origin = get_origin(t)
|
||||||
|
# UnionTypes
|
||||||
|
if type(t) is UnionType or origin is Union:
|
||||||
|
supported = []
|
||||||
|
for arg in get_args(t):
|
||||||
|
try:
|
||||||
|
supported.append(
|
||||||
|
type_to_dict(arg, scope, type_map, narrow_unsupported_union_types)
|
||||||
|
)
|
||||||
|
except JSchemaTypeError:
|
||||||
|
if narrow_unsupported_union_types:
|
||||||
|
# If we are narrowing unsupported union types, we skip the error
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
|
||||||
|
if len(supported) == 0:
|
||||||
|
msg = f"{scope} - No supported types in Union {t!s}, type_map: {type_map}"
|
||||||
|
raise JSchemaTypeError(msg)
|
||||||
|
|
||||||
|
if len(supported) == 1:
|
||||||
|
# If there's only one supported type, return it directly
|
||||||
|
return supported[0]
|
||||||
|
|
||||||
|
# If there are multiple supported types, return them as oneOf
|
||||||
return {
|
return {
|
||||||
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__],
|
"oneOf": supported,
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(t, TypeVar):
|
if isinstance(t, TypeVar):
|
||||||
@@ -221,12 +249,6 @@ def type_to_dict(
|
|||||||
schema = type_to_dict(base_type, scope) # Generate schema for the base type
|
schema = type_to_dict(base_type, scope) # Generate schema for the base type
|
||||||
return apply_annotations(schema, metadata)
|
return apply_annotations(schema, metadata)
|
||||||
|
|
||||||
if origin is Union:
|
|
||||||
union_types = [type_to_dict(arg, scope, type_map) for arg in t.__args__]
|
|
||||||
return {
|
|
||||||
"oneOf": union_types,
|
|
||||||
}
|
|
||||||
|
|
||||||
if origin in {list, set, frozenset, tuple}:
|
if origin in {list, set, frozenset, tuple}:
|
||||||
return {
|
return {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
|||||||
Reference in New Issue
Block a user