Skip to content

Commit

Permalink
Fix Concatenate and Generic with ParamSpec substitution (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
Daraan authored Dec 13, 2024
1 parent 700eadd commit ca41832
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 15 deletions.
176 changes: 171 additions & 5 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3705,6 +3705,10 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
self.assertEqual(Y.__parameters__, ())
self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview))

# Regression test; fixing #126 might cause an error here
with self.assertRaisesRegex(TypeError, "not a generic class"):
Y[int]

def test_protocol_generic_over_typevartuple(self):
Ts = TypeVarTuple("Ts")
T = TypeVar("T")
Expand Down Expand Up @@ -5259,6 +5263,7 @@ class X(Generic[T, P]):
class Y(Protocol[T, P]):
pass

things = "arguments" if sys.version_info >= (3, 10) else "parameters"
for klass in X, Y:
with self.subTest(klass=klass.__name__):
G1 = klass[int, P_2]
Expand All @@ -5273,20 +5278,146 @@ class Y(Protocol[T, P]):
self.assertEqual(G3.__args__, (int, Concatenate[int, ...]))
self.assertEqual(G3.__parameters__, ())

with self.assertRaisesRegex(
TypeError,
f"Too few {things} for {klass}"
):
klass[int]

# The following are some valid uses cases in PEP 612 that don't work:
# These do not work in 3.9, _type_check blocks the list and ellipsis.
# G3 = X[int, [int, bool]]
# G4 = X[int, ...]
# G5 = Z[[int, str, bool]]
# Not working because this is special-cased in 3.10.
# G6 = Z[int, str, bool]

def test_single_argument_generic(self):
P = ParamSpec("P")
T = TypeVar("T")
P_2 = ParamSpec("P_2")

class Z(Generic[P]):
pass

class ProtoZ(Protocol[P]):
pass

for klass in Z, ProtoZ:
with self.subTest(klass=klass.__name__):
# Note: For 3.10+ __args__ are nested tuples here ((int, ),) instead of (int, )
G6 = klass[int, str, T]
G6args = G6.__args__[0] if sys.version_info >= (3, 10) else G6.__args__
self.assertEqual(G6args, (int, str, T))
self.assertEqual(G6.__parameters__, (T,))

# P = [int]
G7 = klass[int]
G7args = G7.__args__[0] if sys.version_info >= (3, 10) else G7.__args__
self.assertEqual(G7args, (int,))
self.assertEqual(G7.__parameters__, ())

G8 = klass[Concatenate[T, ...]]
self.assertEqual(G8.__args__, (Concatenate[T, ...], ))
self.assertEqual(G8.__parameters__, (T,))

G9 = klass[Concatenate[T, P_2]]
self.assertEqual(G9.__args__, (Concatenate[T, P_2], ))

# This is an invalid form but useful for testing correct subsitution
G10 = klass[int, Concatenate[str, P]]
G10args = G10.__args__[0] if sys.version_info >= (3, 10) else G10.__args__
self.assertEqual(G10args, (int, Concatenate[str, P], ))

@skipUnless(TYPING_3_10_0, "ParamSpec not present before 3.10")
def test_is_param_expr(self):
P = ParamSpec("P")
P_typing = typing.ParamSpec("P_typing")
self.assertTrue(typing_extensions._is_param_expr(P))
self.assertTrue(typing_extensions._is_param_expr(P_typing))
if hasattr(typing, "_is_param_expr"):
self.assertTrue(typing._is_param_expr(P))
self.assertTrue(typing._is_param_expr(P_typing))

def test_single_argument_generic_with_parameter_expressions(self):
P = ParamSpec("P")
T = TypeVar("T")
P_2 = ParamSpec("P_2")

class Z(Generic[P]):
pass

class ProtoZ(Protocol[P]):
pass

things = "arguments" if sys.version_info >= (3, 10) else "parameters"
for klass in Z, ProtoZ:
with self.subTest(klass=klass.__name__):
G8 = klass[Concatenate[T, ...]]

H8_1 = G8[int]
self.assertEqual(H8_1.__parameters__, ())
with self.assertRaisesRegex(TypeError, "not a generic class"):
H8_1[str]

H8_2 = G8[T][int]
self.assertEqual(H8_2.__parameters__, ())
with self.assertRaisesRegex(TypeError, "not a generic class"):
H8_2[str]

G9 = klass[Concatenate[T, P_2]]
self.assertEqual(G9.__parameters__, (T, P_2))

with self.assertRaisesRegex(TypeError,
"The last parameter to Concatenate should be a ParamSpec variable or ellipsis."
if sys.version_info < (3, 10) else
# from __typing_subst__
"Expected a list of types, an ellipsis, ParamSpec, or Concatenate"
):
G9[int, int]

with self.assertRaisesRegex(TypeError, f"Too few {things}"):
G9[int]

with self.subTest("Check list as parameter expression", klass=klass.__name__):
if sys.version_info < (3, 10):
self.skipTest("Cannot pass non-types")
G5 = klass[[int, str, T]]
self.assertEqual(G5.__parameters__, (T,))
self.assertEqual(G5.__args__, ((int, str, T),))

H9 = G9[int, [T]]
self.assertEqual(H9.__parameters__, (T,))

# This is an invalid parameter expression but useful for testing correct subsitution
G10 = klass[int, Concatenate[str, P]]
with self.subTest("Check invalid form substitution"):
self.assertEqual(G10.__parameters__, (P, ))
if sys.version_info < (3, 9):
self.skipTest("3.8 typing._type_subst does not support this substitution process")
H10 = G10[int]
if (3, 10) <= sys.version_info < (3, 11, 3):
self.skipTest("3.10-3.11.2 does not substitute Concatenate here")
self.assertEqual(H10.__parameters__, ())
H10args = H10.__args__[0] if sys.version_info >= (3, 10) else H10.__args__
self.assertEqual(H10args, (int, (str, int)))

@skipUnless(TYPING_3_10_0, "ParamSpec not present before 3.10")
def test_substitution_with_typing_variants(self):
# verifies substitution and typing._check_generic working with typing variants
P = ParamSpec("P")
typing_P = typing.ParamSpec("typing_P")
typing_Concatenate = typing.Concatenate[int, P]

class Z(Generic[typing_P]):
pass

P1 = Z[typing_P]
self.assertEqual(P1.__parameters__, (typing_P,))
self.assertEqual(P1.__args__, (typing_P,))

C1 = Z[typing_Concatenate]
self.assertEqual(C1.__parameters__, (P,))
self.assertEqual(C1.__args__, (typing_Concatenate,))

def test_pickle(self):
global P, P_co, P_contra, P_default
P = ParamSpec('P')
Expand Down Expand Up @@ -5468,6 +5599,43 @@ def test_eq(self):
self.assertEqual(hash(C4), hash(C5))
self.assertNotEqual(C4, C6)

def test_substitution(self):
T = TypeVar('T')
P = ParamSpec('P')
Ts = TypeVarTuple("Ts")

C1 = Concatenate[str, T, ...]
self.assertEqual(C1[int], Concatenate[str, int, ...])

C2 = Concatenate[str, P]
self.assertEqual(C2[...], Concatenate[str, ...])
self.assertEqual(C2[int], (str, int))
U1 = Unpack[Tuple[int, str]]
U2 = Unpack[Ts]
self.assertEqual(C2[U1], (str, int, str))
self.assertEqual(C2[U2], (str, Unpack[Ts]))
self.assertEqual(C2["U2"], (str, typing.ForwardRef("U2")))

if (3, 12, 0) <= sys.version_info < (3, 12, 4):
with self.assertRaises(AssertionError):
C2[Unpack[U2]]
else:
with self.assertRaisesRegex(TypeError, "must be used with a tuple type"):
C2[Unpack[U2]]

C3 = Concatenate[str, T, P]
self.assertEqual(C3[int, [bool]], (str, int, bool))

@skipUnless(TYPING_3_10_0, "Concatenate not present before 3.10")
def test_is_param_expr(self):
P = ParamSpec('P')
concat = Concatenate[str, P]
typing_concat = typing.Concatenate[str, P]
self.assertTrue(typing_extensions._is_param_expr(concat))
self.assertTrue(typing_extensions._is_param_expr(typing_concat))
if hasattr(typing, "_is_param_expr"):
self.assertTrue(typing._is_param_expr(concat))
self.assertTrue(typing._is_param_expr(typing_concat))

class TypeGuardTests(BaseTestCase):
def test_basics(self):
Expand Down Expand Up @@ -7465,11 +7633,9 @@ def test_callable_with_concatenate(self):
self.assertEqual(callable_concat.__parameters__, (P2,))
concat_usage = callable_concat[str]
with self.subTest("get_args of Concatenate in TypeAliasType"):
if not TYPING_3_9_0:
if not TYPING_3_10_0:
# args are: ([<class 'int'>, ~P2],)
self.skipTest("Nested ParamSpec is not substituted")
if sys.version_info < (3, 10, 2):
self.skipTest("GenericAlias keeps Concatenate in __args__ prior to 3.10.2")
self.assertEqual(get_args(concat_usage), ((int, str),))
with self.subTest("Equality of parameter_expression without []"):
if not TYPING_3_10_0:
Expand Down
Loading

0 comments on commit ca41832

Please sign in to comment.