Make pyright happy

This commit is contained in:
László Károlyi 2021-11-16 17:54:49 +01:00
parent c012842b58
commit 18e2b4619e
Signed by: karolyi
GPG Key ID: 2DCAF25E55735BFE
2 changed files with 92 additions and 45 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
*.db
*.yaml
venv/
__pycache__

View File

@ -15,20 +15,25 @@ from pathlib import Path
from smtplib import SMTP
from sqlite3 import Connection, connect
from time import time
from typing import List, Tuple
from typing import List, Tuple, Union
from xml.etree.ElementTree import XML, Element
from zipfile import ZipFile, is_zipfile
from dns.rdatatype import PTR
from dns.resolver import NXDOMAIN, NoAnswer, Timeout, resolve
from dns.reversename import from_address
from yaml import CLoader as Loader
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
from yaml import load
_rowid_getter = attrgetter('rowid')
_INSERTED_FIELDS = (
'checked_at, org_name, domain, header_froms, report_id, unixtime_start, '
'unixtime_end, offending_ip, count, failed_types')
'checked_at, org_name, domain, header_froms, report_id, ' +
'unixtime_start, unixtime_end, offending_ip, count, failed_types')
_FETCHED_FIELDS = f'rowid, {_INSERTED_FIELDS}'
LogRow = namedtuple(typename='LogRow', field_names=_FETCHED_FIELDS)
_CFG_TEMPLATE = """\
@ -56,7 +61,7 @@ report:
subject: DMARC Report analysis output
"""
_DESCRIPTION = (
'DMARC Report analyzer and reporter. '
'DMARC Report analyzer and reporter. ' +
'The default mode is to analyze what\'s in the IMAP folder.')
_STARTTIME = time()
ListOfLogRow = List[LogRow]
@ -86,6 +91,17 @@ class ServerException(DmarcReporterBase):
'Raised when the IMAP server says something is wrong.'
class ElementNotFound(DmarcReporterBase):
'Raised when an XML `Element` is expected to be not `None`.'
class ElementTextNotFound(DmarcReporterBase):
    """
    Raised when the XML `Element` found at a path is expected to carry
    text, but its `text` property is `None`.
    """
def _get_sql_connection(config: dict) -> Connection:
    'Open and return the sqlite `Connection` for the configured DB path.'
    # The configured `sqlite_path` is made absolute before connecting.
    db_path = Path(config['sqlite_path']).absolute()
    return connect(database=db_path)
@ -102,20 +118,20 @@ def _init_cfg(parsed_args: Namespace):
path_cfg = Path(parsed_args.cfg_file).absolute()
with path_cfg.open('w') as fd:
fd.write(_CFG_TEMPLATE)
print((
f'Config file template written to {path_cfg}. Please edit it '
'before running this program.'))
print(
f'Config file template written to {path_cfg}. Please edit it ' +
'before running this program.')
def _init_db(parsed_args: Namespace, config: dict):
def _init_db(config: dict):
'Initialize a blank DB.'
conn = connect(database=Path(config['sqlite_path']).absolute())
cursor = conn.cursor()
cursor.execute('DROP TABLE IF EXISTS reports')
cursor.execute(
'CREATE TABLE reports (checked_at INTEGER, org_name TEXT, '
'domain TEXT, header_froms TEXT, report_id TEXT, '
'unixtime_start INTEGER, unixtime_end INTEGER, offending_ip TEXT, '
'CREATE TABLE reports (checked_at INTEGER, org_name TEXT, ' +
'domain TEXT, header_froms TEXT, report_id TEXT, ' +
'unixtime_start INTEGER, unixtime_end INTEGER, offending_ip TEXT, ' +
'count INTEGER, failed_types TEXT)')
conn.commit()
conn.close()
@ -126,21 +142,45 @@ def _init_db(parsed_args: Namespace, config: dict):
class XmlParser(object):
'Parse one report here.'
def __init__(self, content: bytes, config: dict, sql_conn: Connection):
def __init__(self, content: bytes, sql_conn: Connection):
self._content = content
self._sql_conn = sql_conn
self._time = datetime
def _get_element_or_fail(self, element: Element, path: str) -> Element:
    """
    Look up `path` within `element` and return the matching child
    `Element`. Raise `ElementNotFound` when nothing matches.
    """
    found = element.find(path=path)
    if found is not None:
        return found
    raise ElementNotFound(f'{path} not found within {element}')
def _get_text_or_fail(self, element: Element, path: str) -> str:
    """
    Return the text of the child found at `path` within `element`.

    Raise `ElementNotFound` when there is no such child, and
    `ElementTextNotFound` when the child exists but has no text.
    """
    child = self._get_element_or_fail(element=element, path=path)
    if child.text is None:
        # The message names the located child, not the `element` passed in.
        raise ElementTextNotFound(f'{path} within {child} has no text')
    return child.text
def _parse_header(self):
'Parse data headers.'
date_range = self._root.find(path='./report_metadata/date_range')
self._unixtime_start = int(date_range.find(path='begin').text)
self._unixtime_end = int(date_range.find(path='end').text)
self._org_name = self._root.find(
path='./report_metadata/org_name').text
self._domain = self._root.find(path='./policy_published/domain').text
self._report_id = self._root.find(
path='./report_metadata/report_id').text
date_range = self._get_element_or_fail(
element=self._root, path='./report_metadata/date_range')
text_begin = self._get_text_or_fail(element=date_range, path='begin')
self._unixtime_start = int(text_begin)
text_end = self._get_text_or_fail(element=date_range, path='end')
self._unixtime_end = int(text_end)
self._org_name = self._get_text_or_fail(
element=self._root, path='./report_metadata/org_name')
self._domain = self._get_text_or_fail(
element=self._root, path='./policy_published/domain')
self._report_id = self._get_text_or_fail(
element=self._root, path='./report_metadata/report_id')
def _note_failed_records(
self, ip: str, failed: list, count: int, header_froms: list):
@ -148,22 +188,24 @@ class XmlParser(object):
_header_froms = ', '.join(header_froms)
_failed = ', '.join(failed)
with self._sql_conn as conn:
conn.execute((
f'INSERT INTO reports({_INSERTED_FIELDS}) VALUES (?, ?, ?, ?, '
'?, ?, ?, ?, ?, ?)'), (
params = (
_STARTTIME, self._org_name, self._domain, _header_froms,
self._report_id, self._unixtime_start, self._unixtime_end, ip,
count, _failed))
count, _failed)
conn.execute(
f'INSERT INTO reports({_INSERTED_FIELDS}) VALUES ' +
'(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', params)
self._sql_conn.commit()
def _parse_record(self, record: Element):
'Parse one record.'
policy = record.find(path='row/policy_evaluated')
policy = self._get_element_or_fail(
element=record, path='row/policy_evaluated')
failed = [x.tag for x in policy.findall(path='.//*[.="fail"]')]
if not failed:
return
ip = record.find(path='row/source_ip').text
count = int(record.find(path='row/count').text)
ip = self._get_text_or_fail(element=record, path='row/source_ip')
count = int(self._get_text_or_fail(element=record, path='row/count'))
header_froms = [
x.text for x in
record.findall(path='identifiers/header_from')]
@ -185,6 +227,7 @@ class XmlParser(object):
class ImapHandler(object):
'Handling the IMAP connection'
_conn: IMAP4_SSL
_all_messages: List[bytes]
def __init__(self, config: dict):
self._config = config
@ -194,12 +237,14 @@ class ImapHandler(object):
'Extract and return the subject.'
subject = decode_header(header=email['subject'])
for text, encoding in subject:
if encoding is None:
if type(text) is str:
return text
return text.decode(encoding)
return text.decode(encoding=encoding) \
if encoding else text.decode()
return ''
def _get_extracted_gzip_content(self, message: Message) -> Tuple[bytes]:
def _get_extracted_gzip_content(
self, message: Message) -> Tuple[bytes, ...]:
"""
Load and return the extracted XML content of the zip file in the
message.
@ -216,7 +261,8 @@ class ImapHandler(object):
# A finally statement would do here but whatever
return result
def _get_extracted_zip_content(self, message: Message) -> Tuple[bytes]:
def _get_extracted_zip_content(
self, message: Message) -> Tuple[bytes, ...]:
"""
Load and return the extracted XML content of the zip file in the
message.
@ -233,7 +279,7 @@ class ImapHandler(object):
result += (zip_fd.read(),)
return result
def _walk_content(self, message: Message) -> Tuple[bytes]:
def _walk_content(self, message: Message) -> Tuple[bytes, ...]:
'Walk the content of the message recursively.'
result = tuple()
if message.is_multipart():
@ -252,19 +298,21 @@ class ImapHandler(object):
result += self._get_extracted_gzip_content(message=message)
return result
def _parse_message(self, num: bytes) -> Tuple[bytes]:
def _parse_message(self, num: str) -> Union[Tuple[bytes, ...], None]:
'Return the parsed XML content from the parsed message.'
response, msg = self._conn.fetch(
message_set=num, message_parts='(RFC822)')
if response != 'OK' or msg[0] is None or type(msg[0]) is not tuple:
return
message = message_from_bytes(s=msg[0][1])
# subject = self._get_subject(email=message)
extracted_content = self._walk_content(message=message)
# print(subject, extracted_content)
return extracted_content
def _move_processed_messages(self, to_be_moved: list):
def _move_processed_messages(self, to_be_moved: List[bytes]):
'Move processed messages to the designated `Trash`.'
message_set = b','.join(to_be_moved)
message_set = b','.join(to_be_moved).decode()
self._conn.copy(
message_set=message_set,
new_mailbox=self._config['imap']['trash_path'])
@ -279,23 +327,21 @@ class ImapHandler(object):
password=self._config['imap']['password'])
response, result = self._conn.select(
mailbox=self._config['imap']['folder_path'])
if response != 'OK':
if response != 'OK' or result[0] is None:
raise ServerException()
self._no_messages = int(result[0])
response, self._all_messages = self._conn.search(
None, 'ALL') # type: Tuple[str, List[bytes]]
response, self._all_messages = self._conn.search(None, 'ALL')
if response != 'OK':
raise ServerException()
to_be_moved = []
for num in self._all_messages[0].split(): # type: bytes
extracted_content = self._parse_message(num=num)
if not extracted_content:
if not extracted_content or type(num) is not bytes:
continue
to_be_moved.append(num)
for content_item in extracted_content:
parser = XmlParser(
content=content_item, config=self._config,
sql_conn=self._sql_conn)
content=content_item, sql_conn=self._sql_conn)
parser.process()
if to_be_moved:
self._move_processed_messages(to_be_moved=to_be_moved)
@ -359,8 +405,8 @@ class ReportSender(object):
datetime_end = datetime.fromtimestamp(row.unixtime_end)
text += _EMAIL_ITEMS.format(
**row._asdict(),
datetime_start=datetime_start.strftime(format='%c'),
datetime_end=datetime_end.strftime(format='%c'),
datetime_start=datetime_start.strftime('%c'),
datetime_end=datetime_end.strftime('%c'),
hostnames=self._get_hostnames(ip=row.offending_ip))
return text
@ -421,7 +467,7 @@ def main():
return _init_cfg(parsed_args=parsed_args)
config = _get_loaded_cfg(parsed_args=parsed_args)
if parsed_args.init_db:
return _init_db(parsed_args=parsed_args, config=config)
return _init_db(config=config)
if parsed_args.report:
report_sender = ReportSender(config=config)
return report_sender.process()