Split mx_check into dns_check and smtp_check

This will allow us to cleanly and consistently keep the four
verification steps separate: format (regex) - blacklist - dns - smtp.
This commit is contained in:
Reinhard Müller 2021-03-02 18:30:13 +01:00
parent a0f1cd1b04
commit bcbadbab65
6 changed files with 191 additions and 187 deletions

66
tests/test_dns_check.py Normal file
View File

@ -0,0 +1,66 @@
from types import SimpleNamespace
from unittest.case import TestCase
from unittest.mock import Mock, patch
from dns.exception import Timeout
from validate_email import dns_check
from validate_email.exceptions import DNSTimeoutError, NoValidMXError
from validate_email.dns_check import _get_cleaned_mx_records
class DnsNameStub:
    """Minimal stand-in for `dns.name.Name`, as far as the tests need it."""

    def __init__(self, value: str):
        # Keep the raw text so that `to_text` can hand it back unchanged.
        self.value = value

    def to_text(self) -> str:
        """Return the text this stub was created with."""
        return self.value
# Shared replacement for `resolve`: the tests below patch it into `dns_check`
# and set its `return_value` / `side_effect` to simulate DNS answers.
TEST_QUERY = Mock()
class GetMxRecordsTestCase(TestCase):
    """
    Testing `_get_cleaned_mx_records`.

    Every test patches `resolve` in `dns_check` with the shared
    `TEST_QUERY` mock and configures that mock for its own scenario.
    """

    def setUp(self):
        # `TEST_QUERY` lives at module level and is reused by all tests.
        # Clear any `return_value` / `side_effect` a previous test left
        # behind, so the results do not depend on test execution order.
        TEST_QUERY.reset_mock(return_value=True, side_effect=True)

    @patch.object(target=dns_check, attribute='resolve', new=TEST_QUERY)
    def test_fails_with_invalid_hostnames(self):
        'Fails when the only MX hostname is invalid.'
        TEST_QUERY.return_value = [
            SimpleNamespace(exchange=DnsNameStub(value='asdqwe'))]
        with self.assertRaises(NoValidMXError) as exc:
            _get_cleaned_mx_records(domain='testdomain1', timeout=10)
        self.assertTupleEqual(exc.exception.args, ())

    @patch.object(target=dns_check, attribute='resolve', new=TEST_QUERY)
    def test_fails_with_null_hostnames(self):
        'Fails when the only MX hostname is ".".'
        TEST_QUERY.return_value = [
            SimpleNamespace(exchange=DnsNameStub(value='.'))]
        with self.assertRaises(NoValidMXError) as exc:
            _get_cleaned_mx_records(domain='testdomain2', timeout=10)
        self.assertTupleEqual(exc.exception.args, ())

    @patch.object(target=dns_check, attribute='resolve', new=TEST_QUERY)
    def test_filters_out_invalid_hostnames(self):
        'Returns only the valid hostnames, without duplicates.'
        TEST_QUERY.return_value = [
            SimpleNamespace(exchange=DnsNameStub(value='asdqwe.')),
            SimpleNamespace(exchange=DnsNameStub(value='.')),
            SimpleNamespace(exchange=DnsNameStub(value='valid.host.')),
            # This is an intentional duplicate.
            SimpleNamespace(exchange=DnsNameStub(value='valid.host.')),
            SimpleNamespace(exchange=DnsNameStub(value='valid2.host.')),
        ]
        result = _get_cleaned_mx_records(domain='testdomain3', timeout=10)
        self.assertListEqual(result, ['valid.host', 'valid2.host'])

    @patch.object(target=dns_check, attribute='resolve', new=TEST_QUERY)
    def test_raises_exception_on_dns_timeout(self):
        'Raises exception on DNS timeout.'
        # Cleared again by `setUp` before the next test runs.
        TEST_QUERY.side_effect = Timeout()
        with self.assertRaises(DNSTimeoutError) as exc:
            _get_cleaned_mx_records(domain='testdomain3', timeout=10)
        self.assertTupleEqual(exc.exception.args, ())

View File

@ -1,111 +0,0 @@
from smtplib import SMTP
from types import SimpleNamespace
from unittest.case import TestCase
from unittest.mock import Mock, patch
from dns.exception import Timeout
from validate_email import mx_check as mx_module
from validate_email.email_address import EmailAddress
from validate_email.exceptions import (
DNSTimeoutError, NoValidMXError, SMTPCommunicationError, SMTPMessage,
SMTPTemporaryError)
from validate_email.mx_check import (
_get_cleaned_mx_records, _SMTPChecker, mx_check)
class DnsNameStub:
    """Minimal stand-in for `dns.name.Name`, as far as the tests need it."""

    def __init__(self, value: str):
        # Keep the raw text so that `to_text` can hand it back unchanged.
        self.value = value

    def to_text(self) -> str:
        """Return the text this stub was created with."""
        return self.value
# Shared replacement for `resolve`: the tests below patch it into `mx_module`
# and set its `return_value` / `side_effect` to simulate DNS answers.
TEST_QUERY = Mock()
class GetMxRecordsTestCase(TestCase):
    """
    Testing `_get_cleaned_mx_records` and the `skip_smtp` switch of
    `mx_check`.

    The tests patch `resolve` in `mx_module` with the shared `TEST_QUERY`
    mock and configure that mock for their own scenario.
    """

    def setUp(self):
        # `TEST_QUERY` lives at module level and is reused by all tests.
        # Clear any `return_value` / `side_effect` a previous test left
        # behind, so the results do not depend on test execution order.
        TEST_QUERY.reset_mock(return_value=True, side_effect=True)

    @patch.object(target=mx_module, attribute='resolve', new=TEST_QUERY)
    def test_fails_with_invalid_hostnames(self):
        'Fails when the only MX hostname is invalid.'
        TEST_QUERY.return_value = [
            SimpleNamespace(exchange=DnsNameStub(value='asdqwe'))]
        with self.assertRaises(NoValidMXError) as exc:
            _get_cleaned_mx_records(domain='testdomain1', timeout=10)
        self.assertTupleEqual(exc.exception.args, ())

    @patch.object(target=mx_module, attribute='resolve', new=TEST_QUERY)
    def test_fails_with_null_hostnames(self):
        'Fails when the only MX hostname is ".".'
        TEST_QUERY.return_value = [
            SimpleNamespace(exchange=DnsNameStub(value='.'))]
        with self.assertRaises(NoValidMXError) as exc:
            _get_cleaned_mx_records(domain='testdomain2', timeout=10)
        self.assertTupleEqual(exc.exception.args, ())

    @patch.object(target=mx_module, attribute='resolve', new=TEST_QUERY)
    def test_filters_out_invalid_hostnames(self):
        'Returns only the valid hostnames, without duplicates.'
        TEST_QUERY.return_value = [
            SimpleNamespace(exchange=DnsNameStub(value='asdqwe.')),
            SimpleNamespace(exchange=DnsNameStub(value='.')),
            SimpleNamespace(exchange=DnsNameStub(value='valid.host.')),
            # This is an intentional duplicate.
            SimpleNamespace(exchange=DnsNameStub(value='valid.host.')),
            SimpleNamespace(exchange=DnsNameStub(value='valid2.host.')),
        ]
        result = _get_cleaned_mx_records(domain='testdomain3', timeout=10)
        self.assertListEqual(result, ['valid.host', 'valid2.host'])

    @patch.object(target=mx_module, attribute='resolve', new=TEST_QUERY)
    def test_raises_exception_on_dns_timeout(self):
        'Raises exception on DNS timeout.'
        # Cleared again by `setUp` before the next test runs.
        TEST_QUERY.side_effect = Timeout()
        with self.assertRaises(DNSTimeoutError) as exc:
            _get_cleaned_mx_records(domain='testdomain3', timeout=10)
        self.assertTupleEqual(exc.exception.args, ())

    @patch.object(target=mx_module, attribute='resolve', new=TEST_QUERY)
    @patch.object(target=_SMTPChecker, attribute='check')
    def test_skip_smtp_argument(self, check_mx_records_mock):
        'Check correct work of `skip_smtp` argument.'
        # Answer the MX lookup from the mock, so the test does not need
        # a live DNS query for the address' domain.
        TEST_QUERY.return_value = [
            SimpleNamespace(exchange=DnsNameStub(value='mx.mail.ru.'))]
        self.assertTrue(mx_check(
            EmailAddress('test@mail.ru'), debug=False, skip_smtp=True))
        # With `skip_smtp=True` the SMTP conversation must never start.
        self.assertEqual(check_mx_records_mock.call_count, 0)
class SMTPCheckerTest(TestCase):
    'Checking the `_SMTPChecker` class methods.'

    # Constructor arguments shared by every checker built in these tests.
    CHECKER_KWARGS = dict(
        local_hostname='localhost', timeout=5, debug=False,
        sender='test@example.com', recip='test@example.com')

    @patch.object(target=SMTP, attribute='connect')
    def test_connect_raises_serverdisconnected(self, mock_connect):
        'An `OSError` from connect ends in `SMTPCommunicationError`.'
        mock_connect.side_effect = OSError('test message')
        checker = _SMTPChecker(**self.CHECKER_KWARGS)
        expected = {
            'testhost': SMTPMessage(
                command='connect', code=0, text='test message'),
        }
        with self.assertRaises(SMTPCommunicationError) as exc:
            checker.check(hosts=['testhost'])
        self.assertDictEqual(exc.exception.error_messages, expected)

    @patch.object(target=SMTP, attribute='connect')
    def test_connect_with_error(self, mock_connect):
        'A 400 reply to connect ends in `SMTPTemporaryError`.'
        checker = _SMTPChecker(**self.CHECKER_KWARGS)
        mock_connect.return_value = (400, b'test delay message')
        expected = {
            'testhost': SMTPMessage(
                command='connect', code=400, text='test delay message'),
        }
        with self.assertRaises(SMTPTemporaryError) as exc:
            checker.check(hosts=['testhost'])
        self.assertDictEqual(exc.exception.error_messages, expected)

39
tests/test_smtp_check.py Normal file
View File

@ -0,0 +1,39 @@
from smtplib import SMTP
from unittest.case import TestCase
from unittest.mock import patch
from validate_email.exceptions import (
SMTPCommunicationError, SMTPMessage, SMTPTemporaryError)
from validate_email.smtp_check import _SMTPChecker
class SMTPCheckerTest(TestCase):
    'Checking the `_SMTPChecker` class methods.'

    # Constructor arguments shared by every checker built in these tests.
    CHECKER_KWARGS = dict(
        local_hostname='localhost', timeout=5, debug=False,
        sender='test@example.com', recip='test@example.com')

    @patch.object(target=SMTP, attribute='connect')
    def test_connect_raises_serverdisconnected(self, mock_connect):
        'An `OSError` from connect ends in `SMTPCommunicationError`.'
        mock_connect.side_effect = OSError('test message')
        checker = _SMTPChecker(**self.CHECKER_KWARGS)
        expected = {
            'testhost': SMTPMessage(
                command='connect', code=0, text='test message'),
        }
        with self.assertRaises(SMTPCommunicationError) as exc:
            checker.check(hosts=['testhost'])
        self.assertDictEqual(exc.exception.error_messages, expected)

    @patch.object(target=SMTP, attribute='connect')
    def test_connect_with_error(self, mock_connect):
        'A 400 reply to connect ends in `SMTPTemporaryError`.'
        checker = _SMTPChecker(**self.CHECKER_KWARGS)
        mock_connect.return_value = (400, b'test delay message')
        expected = {
            'testhost': SMTPMessage(
                command='connect', code=400, text='test delay message'),
        }
        with self.assertRaises(SMTPTemporaryError) as exc:
            checker.check(hosts=['testhost'])
        self.assertDictEqual(exc.exception.error_messages, expected)

View File

@ -0,0 +1,65 @@
from dns.exception import Timeout
from dns.rdatatype import MX as rdtype_mx
from dns.rdtypes.ANY.MX import MX
from dns.resolver import (
NXDOMAIN, YXDOMAIN, Answer, NoAnswer, NoNameservers, resolve)
from .constants import HOST_REGEX
from .email_address import EmailAddress
from .exceptions import (
DNSConfigurationError, DNSTimeoutError, DomainNotFoundError, NoMXError,
NoNameserverError, NoValidMXError)
def _get_mx_records(domain: str, timeout: int) -> list:
    """
    Return the DNS response to an MX query for `domain`.

    `timeout` is handed to the resolver as the `lifetime` of the query.
    Resolver failures are translated into this package's exceptions:

    - `NXDOMAIN` -> `DomainNotFoundError`
    - `NoNameservers` -> `NoNameserverError`
    - `Timeout` -> `DNSTimeoutError`
    - `YXDOMAIN` -> `DNSConfigurationError`
    - `NoAnswer` -> `NoMXError`

    The resolver's exception is attached as `__cause__` of the raised
    one, so tracebacks show it as the direct cause instead of as an
    error that happened while handling another error.
    """
    try:
        return resolve(
            qname=domain, rdtype=rdtype_mx, lifetime=timeout,
            search=True)  # type: Answer
    except NXDOMAIN as exc:
        raise DomainNotFoundError from exc
    except NoNameservers as exc:
        raise NoNameserverError from exc
    except Timeout as exc:
        raise DNSTimeoutError from exc
    except YXDOMAIN as exc:
        raise DNSConfigurationError from exc
    except NoAnswer as exc:
        raise NoMXError from exc
def _get_cleaned_mx_records(domain: str, timeout: int) -> list:
    """
    Look up the MX records of `domain` and return their hostnames.

    Trailing dots are stripped, duplicates are dropped (the first
    occurrence wins and the order is kept) and names not matching
    `HOST_REGEX` are discarded. Raise `NoValidMXError` if nothing is
    left; exceptions from `_get_mx_records` propagate unchanged.
    """
    seen = set()
    hostnames = []
    for mx in _get_mx_records(domain=domain, timeout=timeout):  # type: MX
        hostname = mx.exchange.to_text().rstrip('.')  # type: str
        if hostname in seen:
            continue
        seen.add(hostname)
        # Rejected names stay in `seen`, but are never returned.
        if HOST_REGEX.search(string=hostname):
            hostnames.append(hostname)
    if not hostnames:
        raise NoValidMXError
    return hostnames
def dns_check(email_address: EmailAddress, dns_timeout: int = 10) -> list:
    """
    Return the hosts responsible for receiving mail for `email_address`.

    If the domain part of the address is an IP literal, that IP is the
    only host and no DNS lookup takes place. Otherwise the MX records of
    the domain are looked up, with `dns_timeout` as the lookup timeout,
    and the cleaned list of MX hostnames is returned.

    Exceptions raised by `_get_cleaned_mx_records`, when no usable MX
    host can be determined, propagate to the caller.
    """
    literal_ip = email_address.domain_literal_ip
    if literal_ip:
        return [literal_ip]
    return _get_cleaned_mx_records(
        domain=email_address.domain, timeout=dns_timeout)

View File

@ -3,60 +3,14 @@ from smtplib import (
SMTP, SMTPNotSupportedError, SMTPResponseException, SMTPServerDisconnected)
from typing import List, Optional, Tuple
from dns.exception import Timeout
from dns.rdatatype import MX as rdtype_mx
from dns.rdtypes.ANY.MX import MX
from dns.resolver import (
NXDOMAIN, YXDOMAIN, Answer, NoAnswer, NoNameservers, resolve)
from .constants import HOST_REGEX
from .email_address import EmailAddress
from .exceptions import (
AddressNotDeliverableError, DNSConfigurationError, DNSTimeoutError,
DomainNotFoundError, NoMXError, NoNameserverError, NoValidMXError,
SMTPCommunicationError, SMTPMessage, SMTPTemporaryError)
AddressNotDeliverableError, SMTPCommunicationError, SMTPMessage,
SMTPTemporaryError)
LOGGER = getLogger(name=__name__)
def _get_mx_records(domain: str, timeout: int) -> list:
    """
    Return the DNS response to an MX query for `domain`.

    `timeout` is handed to the resolver as the `lifetime` of the query.
    Resolver failures are translated into this package's exceptions:

    - `NXDOMAIN` -> `DomainNotFoundError`
    - `NoNameservers` -> `NoNameserverError`
    - `Timeout` -> `DNSTimeoutError`
    - `YXDOMAIN` -> `DNSConfigurationError`
    - `NoAnswer` -> `NoMXError`

    The resolver's exception is attached as `__cause__` of the raised
    one, so tracebacks show it as the direct cause instead of as an
    error that happened while handling another error.
    """
    try:
        return resolve(
            qname=domain, rdtype=rdtype_mx, lifetime=timeout,
            search=True)  # type: Answer
    except NXDOMAIN as exc:
        raise DomainNotFoundError from exc
    except NoNameservers as exc:
        raise NoNameserverError from exc
    except Timeout as exc:
        raise DNSTimeoutError from exc
    except YXDOMAIN as exc:
        raise DNSConfigurationError from exc
    except NoAnswer as exc:
        raise NoMXError from exc
def _get_cleaned_mx_records(domain: str, timeout: int) -> list:
    """
    Look up the MX records of `domain` and return their hostnames.

    Trailing dots are stripped, duplicates are dropped (the first
    occurrence wins and the order is kept) and names not matching
    `HOST_REGEX` are discarded. Raise `NoValidMXError` if nothing is
    left; exceptions from `_get_mx_records` propagate unchanged.
    """
    seen = set()
    hostnames = []
    for mx in _get_mx_records(domain=domain, timeout=timeout):  # type: MX
        hostname = mx.exchange.to_text().rstrip('.')  # type: str
        if hostname in seen:
            continue
        seen.add(hostname)
        # Rejected names stay in `seen`, but are never returned.
        if HOST_REGEX.search(string=hostname):
            hostnames.append(hostname)
    if not hostnames:
        raise NoValidMXError
    return hostnames
class _SMTPChecker(SMTP):
"""
A specialized variant of `smtplib.SMTP` for checking the validity of
@ -209,7 +163,7 @@ class _SMTPChecker(SMTP):
def check(self, hosts: List[str]) -> bool:
"""
Run the check for all given SMTP servers. On positive result,
return `True`, else raise exceptions described in `mx_check`.
return `True`, else raise exceptions described in `smtp_check`.
"""
for host in hosts:
LOGGER.debug(msg=f'Trying {host} ...')
@ -223,41 +177,27 @@ class _SMTPChecker(SMTP):
raise SMTPTemporaryError(error_messages=self.__temporary_errors)
def mx_check(
email_address: EmailAddress, debug: bool,
def smtp_check(
email_address: EmailAddress, mx_records: list, debug: bool,
from_address: Optional[EmailAddress] = None,
helo_host: Optional[str] = None, smtp_timeout: int = 10,
dns_timeout: int = 10, skip_smtp: bool = False) -> bool:
helo_host: Optional[str] = None, smtp_timeout: int = 10) -> bool:
"""
Returns `True` as soon as the any server accepts the recipient
address.
Returns `True` as soon as any of the given servers accepts the
recipient address.
Raise an `AddressNotDeliverableError` if any server unambiguously
and permanently refuses to accept the recipient address.
Raise `SMTPTemporaryError` if the server answers with a temporary
error code when validity of the email address can not be
determined. Greylisting or server delivery issues can be a cause for
this.
error code when validity of the email address can not be determined.
Greylisting or server delivery issues can be a cause for this.
Raise `SMTPCommunicationError` if the SMTP server(s) reply with an
error message to any of the communication steps before the recipient
address is checked, and the validity of the email address can not be
determined either.
In case no responsible SMTP servers can be determined, a variety of
exceptions is raised depending on the exact issue, all derived from
`MXError`.
"""
from_address = from_address or email_address
if email_address.domain_literal_ip:
mx_records = [email_address.domain_literal_ip]
else:
mx_records = _get_cleaned_mx_records(
domain=email_address.domain, timeout=dns_timeout)
if skip_smtp:
return True
smtp_checker = _SMTPChecker(
local_hostname=helo_host, timeout=smtp_timeout, debug=debug,
sender=from_address, recip=email_address)
sender=from_address or email_address, recip=email_address)
return smtp_checker.check(hosts=mx_records)

View File

@ -1,13 +1,14 @@
from logging import getLogger
from typing import Optional
from .dns_check import dns_check
from .domainlist_check import domainlist_check
from .email_address import EmailAddress
from .exceptions import (
AddressFormatError, EmailValidationError, FromAddressFormatError,
SMTPTemporaryError)
from .mx_check import mx_check
from .regex_check import regex_check
from .smtp_check import smtp_check
LOGGER = getLogger(name=__name__)
@ -47,10 +48,14 @@ def validate_email_or_fail(
domainlist_check(address=email_address)
if not check_mx:
return True
return mx_check(
email_address=email_address, from_address=from_address,
helo_host=helo_host, smtp_timeout=smtp_timeout,
dns_timeout=dns_timeout, skip_smtp=skip_smtp, debug=debug)
mx_records = dns_check(
email_address=email_address, dns_timeout=dns_timeout)
if skip_smtp:
return True
return smtp_check(
email_address=email_address, mx_records=mx_records,
from_address=from_address, helo_host=helo_host,
smtp_timeout=smtp_timeout, debug=debug)
def validate_email(email_address: str, *args, **kwargs):