Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
filipsalo committed Feb 19, 2023
1 parent a75b3e3 commit ac9d062
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
15 changes: 9 additions & 6 deletions surblclient/blacklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
"""Main class for the blacklists"""

import socket
from typing import Literal


def is_ip_address(domain):
def is_ip_address(domain) -> bool:
"""Return True if `domain` is an IP address"""
return all(part.isdigit() for part in domain.split("."))

Expand All @@ -36,14 +37,16 @@ class Blacklist:
domain = ""
flags = []

def __init__(self):
def __init__(self) -> None:
self._cache = (None, None)

def get_base_domain(self, domain):
def get_base_domain(self, domain: str) -> str:
"""Return the base domain to use for RBL lookup"""
return domain

def _lookup_exact(self, domain):
def _lookup_exact(
self, domain: str
) -> tuple[str, list[str]] | Literal[False] | None:
"""Like 'lookup', but checks the exact domain name given.
Not for direct use.
"""
Expand Down Expand Up @@ -74,7 +77,7 @@ def _lookup_exact(self, domain):
return (domain, [s for (n, s) in self.flags if flags & n])
return False

def lookup(self, domain):
def lookup(self, domain: str) -> tuple[str, list[str]] | Literal[False] | None:
"""Extract base domain and check it against SURBL.
Return (basedomain, lists) tuple, where basedomain is the
base domain and lists is a list of strings indicating which
Expand All @@ -94,7 +97,7 @@ def lookup(self, domain):
domain = self.get_base_domain(domain)
return self._lookup_exact(domain)

def __contains__(self, domain):
def __contains__(self, domain: str) -> bool:
"""Return True if base domain is listed in this blacklist;
False otherwise.
"""
Expand Down
6 changes: 3 additions & 3 deletions surblclient/surbl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .blacklist import Blacklist


def domains_from_resource(filename):
def domains_from_resource(filename: str) -> set:
"""Return the domains listen in a data resource file"""
resource_file = importlib.resources.files(__package__) / filename
with resource_file.open("r", encoding="utf-8", errors="strict") as resource_fp:
Expand All @@ -40,15 +40,15 @@ class SURBL(Blacklist):
test_domains = {"surbl.org", "multi.surbl.org"}
flags = [(8, "ph"), (16, "mw"), (64, "abuse"), (128, "cr")]

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._pseudo_tlds = (
domains_from_resource("surbl-two-level-tlds")
| domains_from_resource("surbl-three-level-tlds")
| self.test_domains
)

def get_base_domain(self, domain):
def get_base_domain(self, domain: str) -> str:
while domain.count(".") > 1:
_, _, rest = domain.partition(".")
if rest in self._pseudo_tlds:
Expand Down

0 comments on commit ac9d062

Please sign in to comment.