Skip to content

Commit

Permalink
add code and unit test for detecting missing pydantic dependency and …
Browse files Browse the repository at this point in the history
…output to stdout
  • Loading branch information
yctomwang authored and Kludex committed Mar 29, 2024
1 parent 1296c19 commit 98f0166
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 9 deletions.
18 changes: 18 additions & 0 deletions bump_pydantic/codemods/replace_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from dataclasses import dataclass
from typing import Sequence

import sys
import libcst as cst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor
from importlib.util import find_spec

IMPORTS = {
"pydantic:BaseSettings": ("pydantic_settings", "BaseSettings"),
Expand All @@ -35,6 +37,13 @@
}


def find_package_install(package_name: str) -> bool:
try:
return find_spec(package_name) is not None
except ModuleNotFoundError:
return False


def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name:
if len(module_parts) == 1:
return m.Name(module_parts[0])
Expand Down Expand Up @@ -98,11 +107,20 @@ class ImportInfo:
class ReplaceImportsCodemod(VisitorBasedCodemodCommand):
@m.leave(IMPORT_MATCH)
def leave_replace_import(self, _: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
to_do_warnings = set()
for import_info in IMPORT_INFOS:
if m.matches(updated_node, import_info.import_from):
aliases: Sequence[cst.ImportAlias] = updated_node.names # type: ignore
# If multiple objects are imported in a single import statement,
# we need to remove only the one we're replacing.
package_not_installed = not find_package_install(import_info.to_import_str[0])
if package_not_installed:
import_info_part = import_info.to_import_str[0].split('.')[0]
to_do_warning = f" #todo: please install {import_info_part}\n"
if to_do_warning not in to_do_warnings:
sys.stdout.write(to_do_warning)
sys.stdout.flush()
to_do_warnings.add(to_do_warning)
AddImportsVisitor.add_needed_import(self.context, *import_info.to_import_str)
if len(updated_node.names) > 1: # type: ignore
names = [alias for alias in aliases if alias.name.value != import_info.to_import_str[-1]]
Expand Down
91 changes: 82 additions & 9 deletions tests/unit/test_replace_imports.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import pytest
import sys
import io
import importlib
from libcst.codemod import CodemodTest

from contextlib import contextmanager
from bump_pydantic.codemods.replace_imports import ReplaceImportsCodemod


def is_package_installed(package_name):
try:
importlib.import_module(package_name)
return True
except ImportError:
return False


class TestReplaceImportsCommand(CodemodTest):
TRANSFORM = ReplaceImportsCodemod

@contextmanager
def capture_stdout(self):
new_stdout = io.StringIO()
old_stdout = sys.stdout
sys.stdout = new_stdout
yield new_stdout
sys.stdout = old_stdout

def test_base_settings(self) -> None:
before = """
from pydantic import BaseSettings
Expand All @@ -16,11 +35,22 @@ def test_base_settings(self) -> None:
"""
self.assertCodemod(before, after)

with self.capture_stdout() as captured:
self.assertCodemod(before, after)

if is_package_installed("pydantic_settings"):
assert captured.getvalue().strip() == "", "stdout is not empty as expected."
else:
expected_stdout = "#todo: please install pydantic_settings"
assert captured.getvalue().strip() == expected_stdout

def test_noop_base_settings(self) -> None:
code = """
from potato import BaseSettings
"""
self.assertCodemod(code, code)
with self.capture_stdout() as captured:
self.assertCodemod(code, code)
assert captured.getvalue().strip() == "", "stdout is not empty as expected."

@pytest.mark.xfail(reason="To be implemented.")
def test_base_settings_as(self) -> None:
Expand All @@ -39,7 +69,15 @@ def test_color(self) -> None:
after = """
from pydantic_extra_types.color import Color
"""
self.assertCodemod(before, after)

with self.capture_stdout() as captured:
self.assertCodemod(before, after)

if is_package_installed("pydantic_extra_types"):
assert captured.getvalue().strip() == "", "stdout is not empty as expected."
else:
expected_stdout = "#todo: please install pydantic_extra_types"
assert captured.getvalue().strip() == expected_stdout

def test_color_full(self) -> None:
before = """
Expand All @@ -48,13 +86,23 @@ def test_color_full(self) -> None:
after = """
from pydantic_extra_types.color import Color
"""
self.assertCodemod(before, after)
with self.capture_stdout() as captured:
self.assertCodemod(before, after)

if is_package_installed("pydantic_extra_types"):
assert captured.getvalue().strip() == "", "stdout is not empty as expected."
else:
expected_stdout = "#todo: please install pydantic_extra_types"
assert captured.getvalue().strip() == expected_stdout

def test_noop_color(self) -> None:
code = """
from potato import Color
"""
self.assertCodemod(code, code)
with self.capture_stdout() as captured:
self.assertCodemod(code, code)
assert captured.getvalue().strip() == "", "stdout is not empty as expected."

def test_payment_card_number(self) -> None:
before = """
Expand All @@ -63,7 +111,14 @@ def test_payment_card_number(self) -> None:
after = """
from pydantic_extra_types.payment import PaymentCardNumber
"""
self.assertCodemod(before, after)
with self.capture_stdout() as captured:
self.assertCodemod(before, after)

if is_package_installed("pydantic_extra_types"):
assert captured.getvalue().strip() == "", "stdout is not empty as expected."
else:
expected_stdout = "#todo: please install pydantic_extra_types"
assert captured.getvalue().strip() == expected_stdout

def test_payment_card_brand(self) -> None:
before = """
Expand All @@ -72,19 +127,30 @@ def test_payment_card_brand(self) -> None:
after = """
from pydantic_extra_types.payment import PaymentCardBrand
"""
self.assertCodemod(before, after)
with self.capture_stdout() as captured:
self.assertCodemod(before, after)

if is_package_installed("pydantic_extra_types"):
assert captured.getvalue().strip() == "", "stdout is not empty as expected."
else:
expected_stdout = "#todo: please install pydantic_extra_types"
assert captured.getvalue().strip() == expected_stdout

def test_noop_payment_card_number(self) -> None:
code = """
from potato import PaymentCardNumber
"""
self.assertCodemod(code, code)
with self.capture_stdout() as captured:
self.assertCodemod(code, code)
assert captured.getvalue().strip() == "", "stdout is not empty as expected."

def test_noop_payment_card_brand(self) -> None:
code = """
from potato import PaymentCardBrand
"""
self.assertCodemod(code, code)
with self.capture_stdout() as captured:
self.assertCodemod(code, code)
assert captured.getvalue().strip() == "", "stdout is not empty as expected."

def test_both_payment(self) -> None:
before = """
Expand All @@ -93,4 +159,11 @@ def test_both_payment(self) -> None:
after = """
from pydantic_extra_types.payment import PaymentCardBrand, PaymentCardNumber
"""
self.assertCodemod(before, after)
with self.capture_stdout() as captured:
self.assertCodemod(before, after)

if is_package_installed("pydantic_extra_types"):
assert captured.getvalue().strip() == "", "stdout is not empty as expected."
else:
expected_stdout = "#todo: please install pydantic_extra_types"
assert captured.getvalue().strip() == expected_stdout

0 comments on commit 98f0166

Please sign in to comment.