Skip to content

Commit

Permalink
Allow TypedDict unpacking in Callable types (python#16083)
Browse files Browse the repository at this point in the history
Fixes python#16082

Currently we only allow `Unpack` of a TypedDict when it appears in a
function definition. This PR also allows this in `Callable` types,
similarly to how we do this for variadic types.

Note this still doesn't allow having both variadic unpack and a
TypedDict unpack in the same `Callable`. Supporting this is tricky, so
let's not so this until people will actually ask for this. FWIW we can
always suggest callback protocols for such tricky cases.
  • Loading branch information
ilevkivskyi authored Sep 11, 2023
1 parent 9a35360 commit 9e520c3
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 6 deletions.
4 changes: 3 additions & 1 deletion mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def expr_to_unanalyzed_type(
elif isinstance(expr, EllipsisExpr):
return EllipsisType(expr.line)
elif allow_unpack and isinstance(expr, StarExpr):
return UnpackType(expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax))
return UnpackType(
expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax), from_star_syntax=True
)
else:
raise TypeTranslationError()
2 changes: 1 addition & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ def visit_Attribute(self, n: Attribute) -> Type:

# Used for Callable[[X *Ys, Z], R]
def visit_Starred(self, n: ast3.Starred) -> Type:
return UnpackType(self.visit(n.value))
return UnpackType(self.visit(n.value), from_star_syntax=True)

# List(expr* elts, expr_context ctx)
def visit_List(self, n: ast3.List) -> Type:
Expand Down
4 changes: 3 additions & 1 deletion mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ def visit_unpack_type(self, typ: UnpackType) -> None:
# Avoid extra errors if there were some errors already. Also interpret plain Any
# as tuple[Any, ...] (this is better for the code in type checker).
self.fail(
message_registry.INVALID_UNPACK.format(format_type(proper_type, self.options)), typ
message_registry.INVALID_UNPACK.format(format_type(proper_type, self.options)),
typ.type,
code=codes.VALID_TYPE,
)
typ.type = self.named_type("builtins.tuple", [AnyType(TypeOfAny.from_error)])

Expand Down
13 changes: 12 additions & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,14 +961,15 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
if not self.allow_unpack:
self.fail(message_registry.INVALID_UNPACK_POSITION, t.type, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
return UnpackType(self.anal_type(t.type))
return UnpackType(self.anal_type(t.type), from_star_syntax=t.from_star_syntax)

def visit_parameters(self, t: Parameters) -> Type:
raise NotImplementedError("ParamSpec literals cannot have unbound TypeVars")

def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
# Every Callable can bind its own type variables, if they're not in the outer scope
with self.tvar_scope_frame():
unpacked_kwargs = False
if self.defining_alias:
variables = t.variables
else:
Expand Down Expand Up @@ -996,6 +997,15 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
)
validated_args.append(AnyType(TypeOfAny.from_error))
else:
if nested and isinstance(at, UnpackType) and i == star_index:
# TODO: it would be better to avoid this get_proper_type() call.
p_at = get_proper_type(at.type)
if isinstance(p_at, TypedDictType) and not at.from_star_syntax:
# Automatically detect Unpack[Foo] in Callable as backwards
# compatible syntax for **Foo, if Foo is a TypedDict.
at = p_at
arg_kinds[i] = ARG_STAR2
unpacked_kwargs = True
validated_args.append(at)
arg_types = validated_args
# If there were multiple (invalid) unpacks, the arg types list will become shorter,
Expand All @@ -1013,6 +1023,7 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
fallback=(t.fallback if t.fallback.type else self.named_type("builtins.function")),
variables=self.anal_var_defs(variables),
type_guard=special,
unpack_kwargs=unpacked_kwargs,
)
return ret

Expand Down
7 changes: 5 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,11 +1053,14 @@ class UnpackType(ProperType):
wild west, technically anything can be present in the wrapped type.
"""

__slots__ = ["type"]
__slots__ = ["type", "from_star_syntax"]

def __init__(self, typ: Type, line: int = -1, column: int = -1) -> None:
def __init__(
self, typ: Type, line: int = -1, column: int = -1, from_star_syntax: bool = False
) -> None:
super().__init__(line, column)
self.type = typ
self.from_star_syntax = from_star_syntax

def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_unpack_type(self)
Expand Down
15 changes: 15 additions & 0 deletions test-data/unit/check-varargs.test
Original file line number Diff line number Diff line change
Expand Up @@ -1079,3 +1079,18 @@ class C:
class D:
def __init__(self, **kwds: Unpack[int, str]) -> None: ... # E: Unpack[...] requires exactly one type argument
[builtins fixtures/dict.pyi]

[case testUnpackInCallableType]
from typing import Callable
from typing_extensions import Unpack, TypedDict

class TD(TypedDict):
key: str
value: str

foo: Callable[[Unpack[TD]], None]
foo(key="yes", value=42) # E: Argument "value" has incompatible type "int"; expected "str"
foo(key="yes", value="ok")

bad: Callable[[*TD], None] # E: "TD" cannot be unpacked (must be tuple or TypeVarTuple)
[builtins fixtures/dict.pyi]

0 comments on commit 9e520c3

Please sign in to comment.