diff --git a/docs/guide/builtins.rst b/docs/guide/builtins.rst index 6f2b1484..c588187d 100644 --- a/docs/guide/builtins.rst +++ b/docs/guide/builtins.rst @@ -33,7 +33,7 @@ Built-in Rules - :class:`NoRedundantListComprehension` - :class:`NoStaticIfCondition` - :class:`NoStringTypeAnnotation` -- :class:`ReplaceUnionWithOptional` +- :class:`ReplaceOptionalTypeAnnotation` - :class:`RewriteToComprehension` - :class:`RewriteToLiteral` - :class:`SortedAttributes` @@ -716,45 +716,50 @@ Built-in Rules async def foo() -> Class: return await Class() -.. class:: ReplaceUnionWithOptional +.. class:: ReplaceOptionalTypeAnnotation - Enforces the use of ``Optional[T]`` over ``Union[T, None]`` and ``Union[None, T]``. - See https://docs.python.org/3/library/typing.html#typing.Optional to learn more about Optionals. + Enforces the use of ``T | None`` over ``Optional[T]`` and ``Union[T, None]`` and ``Union[None, T]``. + See https://docs.python.org/3/library/stdtypes.html#types-union. .. attribute:: MESSAGE - `Optional[T]` is preferred over `Union[T, None]` or `Union[None, T]`. Learn more: https://docs.python.org/3/library/typing.html#typing.Optional + `T | None` is preferred over `Optional[T]` or `Union[T, None]` or `Union[None, T]`. Learn more: https://docs.python.org/3/library/stdtypes.html#types-union .. attribute:: AUTOFIX :type: Yes + .. attribute:: PYTHON_VERSION + :type: '>= 3.10' .. attribute:: VALID .. code:: python - def func() -> Optional[str]: + def func() -> str | None: pass .. code:: python - def func() -> Optional[Dict]: + def func() -> Dict | None: pass .. attribute:: INVALID .. code:: python - def func() -> Union[str, None]: + def func() -> Optional[str]: pass + + # suggested fix + def func() -> str | None: + pass + .. code:: python - from typing import Optional def func() -> Union[Dict[str, int], None]: pass # suggested fix - from typing import Optional - def func() -> Optional[Dict[str, int]]: + def func() -> Dict[str, int] | None: pass .. class:: RewriteToComprehension diff --git a/pyproject.toml b/pyproject.toml index 5eb0080c..f3689299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ target-version = ["py38"] [tool.fixit] enable = ["fixit.rules"] -python-version = "3.10" +python-version = "3.8" formatter = "ufmt" [[tool.fixit.overrides]] diff --git a/src/fixit/rules/replace_union_with_optional.py b/src/fixit/rules/replace_optional_type_annotation.py similarity index 51% rename from src/fixit/rules/replace_union_with_optional.py rename to src/fixit/rules/replace_optional_type_annotation.py index d7f2c1b7..8f4a2fa9 100644 --- a/src/fixit/rules/replace_union_with_optional.py +++ b/src/fixit/rules/replace_optional_type_annotation.py @@ -9,27 +9,27 @@ from fixit import Invalid, LintRule, Valid -class ReplaceUnionWithOptional(LintRule): +class ReplaceOptionalTypeAnnotation(LintRule): """ - Enforces the use of ``Optional[T]`` over ``Union[T, None]`` and ``Union[None, T]``. - See https://docs.python.org/3/library/typing.html#typing.Optional to learn more about Optionals. + Enforces the use of ``T | None`` over ``Optional[T]`` and ``Union[T, None]`` and ``Union[None, T]``. + See https://docs.python.org/3/library/stdtypes.html#types-union. """ + PYTHON_VERSION = ">= 3.10" MESSAGE: str = ( - "`Optional[T]` is preferred over `Union[T, None]` or `Union[None, T]`. " - + "Learn more: https://docs.python.org/3/library/typing.html#typing.Optional" + "`T | None` is preferred over `Optional[T]` or `Union[T, None]` or `Union[None, T]`. " + + "Learn more: https://docs.python.org/3/library/stdtypes.html#types-union" ) - METADATA_DEPENDENCIES = (cst.metadata.ScopeProvider,) VALID = [ Valid( """ - def func() -> Optional[str]: + def func() -> str | None: pass """ ), Valid( """ - def func() -> Optional[Dict]: + def func() -> Dict | None: pass """ ), @@ -43,43 +43,41 @@ def func() -> Union[str, int, None]: INVALID = [ Invalid( """ - def func() -> Union[str, None]: + def func() -> Optional[str]: + pass + """, + expected_replacement=""" + def func() -> str | None: pass """, ), Invalid( """ - from typing import Optional def func() -> Union[Dict[str, int], None]: pass """, expected_replacement=""" - from typing import Optional - def func() -> Optional[Dict[str, int]]: + def func() -> Dict[str, int] | None: pass """, ), Invalid( """ - from typing import Optional def func() -> Union[str, None]: pass """, expected_replacement=""" - from typing import Optional - def func() -> Optional[str]: + def func() -> str | None: pass """, ), Invalid( """ - from typing import Optional def func() -> Union[Dict, None]: pass """, expected_replacement=""" - from typing import Optional - def func() -> Optional[Dict]: + def func() -> Dict | None: pass """, ), @@ -87,24 +85,34 @@ def func() -> Optional[Dict]: def leave_Annotation(self, original_node: cst.Annotation) -> None: if self.contains_union_with_none(original_node): - scope = self.get_metadata(cst.metadata.ScopeProvider, original_node, None) nones = 0 indexes = [] replacement = None - if scope is not None and "Optional" in scope: - for s in cst.ensure_type(original_node.annotation, cst.Subscript).slice: - if m.matches(s, m.SubscriptElement(m.Index(m.Name("None")))): - nones += 1 - else: - indexes.append(s.slice) - if not (nones > 1) and len(indexes) == 1: - replacement = original_node.with_changes( - annotation=cst.Subscript( - value=cst.Name("Optional"), - slice=(cst.SubscriptElement(indexes[0]),), - ) + for s in cst.ensure_type(original_node.annotation, cst.Subscript).slice: + if m.matches(s, m.SubscriptElement(m.Index(m.Name("None")))): + nones += 1 + else: + indexes.append(s.slice) + if not (nones > 1) and len(indexes) == 1: + inner_type = cst.ensure_type(indexes[0], cst.Index).value + replacement = original_node.with_changes( + annotation=cst.BinaryOperation( + operator=cst.BitOr(), + left=inner_type, + right=cst.Name("None"), ) - # TODO(T57106602) refactor lint replacement once extract exists + ) + self.report(original_node, replacement=replacement) + elif self.contains_optional(original_node): + subscript_element = cst.ensure_type( + original_node.annotation, cst.Subscript + ).slice[0] + inner_type = cst.ensure_type(subscript_element.slice, cst.Index).value + replacement = original_node.with_changes( + annotation=cst.BinaryOperation( + operator=cst.BitOr(), left=inner_type, right=cst.Name("None") + ) + ) self.report(original_node, replacement=replacement) def contains_union_with_none(self, node: cst.Annotation) -> bool: @@ -126,3 +134,14 @@ def contains_union_with_none(self, node: cst.Annotation) -> bool: ) ), ) + + def contains_optional(self, node: cst.Annotation) -> bool: + return m.matches( + node, + m.Annotation( + m.Subscript( + value=m.Name("Optional"), + slice=[m.SubscriptElement(m.Index())], + ) + ), + )