diff --git a/.gitignore b/.gitignore index 2f525b9..d0b39ed 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.db *.yaml venv/ +__pycache__ diff --git a/analyze.py b/analyze.py index 1a48298..34e75ce 100755 --- a/analyze.py +++ b/analyze.py @@ -15,20 +15,25 @@ from pathlib import Path from smtplib import SMTP from sqlite3 import Connection, connect from time import time -from typing import List, Tuple +from typing import List, Tuple, Union from xml.etree.ElementTree import XML, Element from zipfile import ZipFile, is_zipfile from dns.rdatatype import PTR from dns.resolver import NXDOMAIN, NoAnswer, Timeout, resolve from dns.reversename import from_address -from yaml import CLoader as Loader + +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader + from yaml import load _rowid_getter = attrgetter('rowid') _INSERTED_FIELDS = ( - 'checked_at, org_name, domain, header_froms, report_id, unixtime_start, ' - 'unixtime_end, offending_ip, count, failed_types') + 'checked_at, org_name, domain, header_froms, report_id, ' + + 'unixtime_start, unixtime_end, offending_ip, count, failed_types') _FETCHED_FIELDS = f'rowid, {_INSERTED_FIELDS}' LogRow = namedtuple(typename='LogRow', field_names=_FETCHED_FIELDS) _CFG_TEMPLATE = """\ @@ -56,7 +61,7 @@ report: subject: DMARC Report analysis output """ _DESCRIPTION = ( - 'DMARC Report analyzer and reporter. ' + 'DMARC Report analyzer and reporter. ' + 'The default mode is to analyze what\'s in the IMAP folder.') _STARTTIME = time() ListOfLogRow = List[LogRow] @@ -86,6 +91,17 @@ class ServerException(DmarcReporterBase): 'Raised when the IMAP server says something is wrong.' +class ElementNotFound(DmarcReporterBase): + 'Raised when an XML `Element` is expected to be not `None`.' + + +class ElementTextNotFound(DmarcReporterBase): + """ + Raised when a path within an XML `Element` is expected to have a text + property. + """ + + def _get_sql_connection(config: dict) -> Connection: 'Return an sqlite `Connection` object from the settings.' return connect(database=Path(config['sqlite_path']).absolute()) @@ -102,20 +118,20 @@ def _init_cfg(parsed_args: Namespace): path_cfg = Path(parsed_args.cfg_file).absolute() with path_cfg.open('w') as fd: fd.write(_CFG_TEMPLATE) - print(( - f'Config file template written to {path_cfg}. Please edit it ' - 'before running this program.')) + print( + f'Config file template written to {path_cfg}. Please edit it ' + + 'before running this program.') -def _init_db(parsed_args: Namespace, config: dict): +def _init_db(config: dict): 'Initialize a blank DB.' conn = connect(database=Path(config['sqlite_path']).absolute()) cursor = conn.cursor() cursor.execute('DROP TABLE IF EXISTS reports') cursor.execute( - 'CREATE TABLE reports (checked_at INTEGER, org_name TEXT, ' - 'domain TEXT, header_froms TEXT, report_id TEXT, ' - 'unixtime_start INTEGER, unixtime_end INTEGER, offending_ip TEXT, ' + 'CREATE TABLE reports (checked_at INTEGER, org_name TEXT, ' + + 'domain TEXT, header_froms TEXT, report_id TEXT, ' + + 'unixtime_start INTEGER, unixtime_end INTEGER, offending_ip TEXT, ' + 'count INTEGER, failed_types TEXT)') conn.commit() conn.close() @@ -126,21 +142,45 @@ def _init_db(parsed_args: Namespace, config: dict): class XmlParser(object): 'Parse one report here.' - def __init__(self, content: bytes, config: dict, sql_conn: Connection): + def __init__(self, content: bytes, sql_conn: Connection): self._content = content self._sql_conn = sql_conn self._time = datetime + def _get_element_or_fail(self, element: Element, path: str) -> Element: + """ + Return an the text of the element raise `ElementNotFound` if not found. + """ + result = element.find(path=path) + if result is None: + raise ElementNotFound(f'{path} not found within {element}') + return result + + def _get_text_or_fail(self, element: Element, path: str) -> str: + """ + Return an the text of the element raise `ElementTextNotFound` if + not found. + """ + element = self._get_element_or_fail(element=element, path=path) + text = element.text + if text is None: + raise ElementTextNotFound(f'{path} within {element} has no text') + return text + def _parse_header(self): 'Parse data headers.' - date_range = self._root.find(path='./report_metadata/date_range') - self._unixtime_start = int(date_range.find(path='begin').text) - self._unixtime_end = int(date_range.find(path='end').text) - self._org_name = self._root.find( - path='./report_metadata/org_name').text - self._domain = self._root.find(path='./policy_published/domain').text - self._report_id = self._root.find( - path='./report_metadata/report_id').text + date_range = self._get_element_or_fail( + element=self._root, path='./report_metadata/date_range') + text_begin = self._get_text_or_fail(element=date_range, path='begin') + self._unixtime_start = int(text_begin) + text_end = self._get_text_or_fail(element=date_range, path='end') + self._unixtime_end = int(text_end) + self._org_name = self._get_text_or_fail( + element=self._root, path='./report_metadata/org_name') + self._domain = self._get_text_or_fail( + element=self._root, path='./policy_published/domain') + self._report_id = self._get_text_or_fail( + element=self._root, path='./report_metadata/report_id') def _note_failed_records( self, ip: str, failed: list, count: int, header_froms: list): @@ -148,22 +188,24 @@ class XmlParser(object): _header_froms = ', '.join(header_froms) _failed = ', '.join(failed) with self._sql_conn as conn: - conn.execute(( - f'INSERT INTO reports({_INSERTED_FIELDS}) VALUES (?, ?, ?, ?, ' - '?, ?, ?, ?, ?, ?)'), ( + params = ( _STARTTIME, self._org_name, self._domain, _header_froms, self._report_id, self._unixtime_start, self._unixtime_end, ip, - count, _failed)) + count, _failed) + conn.execute( + f'INSERT INTO reports({_INSERTED_FIELDS}) VALUES ' + + '(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', params) self._sql_conn.commit() def _parse_record(self, record: Element): 'Parse one record.' - policy = record.find(path='row/policy_evaluated') + policy = self._get_element_or_fail( + element=record, path='row/policy_evaluated') failed = [x.tag for x in policy.findall(path='.//*[.="fail"]')] if not failed: return - ip = record.find(path='row/source_ip').text - count = int(record.find(path='row/count').text) + ip = self._get_text_or_fail(element=record, path='row/source_ip') + count = int(self._get_text_or_fail(element=record, path='row/count')) header_froms = [ x.text for x in record.findall(path='identifiers/header_from')] @@ -185,6 +227,7 @@ class XmlParser(object): class ImapHandler(object): 'Handling the IMAP connection' _conn: IMAP4_SSL + _all_messages: List[bytes] def __init__(self, config: dict): self._config = config @@ -194,12 +237,14 @@ class ImapHandler(object): 'Extract and return the subject.' subject = decode_header(header=email['subject']) for text, encoding in subject: - if encoding is None: + if type(text) is str: return text - return text.decode(encoding) + return text.decode(encoding=encoding) \ + if encoding else text.decode() return '' - def _get_extracted_gzip_content(self, message: Message) -> Tuple[bytes]: + def _get_extracted_gzip_content( + self, message: Message) -> Tuple[bytes, ...]: """ Load and return the extracted XML content of the zip file in the message. @@ -216,7 +261,8 @@ class ImapHandler(object): # A finally statement would do here but whatever return result - def _get_extracted_zip_content(self, message: Message) -> Tuple[bytes]: + def _get_extracted_zip_content( + self, message: Message) -> Tuple[bytes, ...]: """ Load and return the extracted XML content of the zip file in the message. @@ -233,7 +279,7 @@ class ImapHandler(object): result += (zip_fd.read(),) return result - def _walk_content(self, message: Message) -> Tuple[bytes]: + def _walk_content(self, message: Message) -> Tuple[bytes, ...]: 'Walk the content of the message recursively.' result = tuple() if message.is_multipart(): @@ -252,19 +298,21 @@ class ImapHandler(object): result += self._get_extracted_gzip_content(message=message) return result - def _parse_message(self, num: bytes) -> Tuple[bytes]: + def _parse_message(self, num: str) -> Union[Tuple[bytes, ...], None]: 'Return the parsed XML content from the parsed message.' response, msg = self._conn.fetch( message_set=num, message_parts='(RFC822)') + if response != 'OK' or msg[0] is None or type(msg[0]) is not tuple: + return message = message_from_bytes(s=msg[0][1]) # subject = self._get_subject(email=message) extracted_content = self._walk_content(message=message) # print(subject, extracted_content) return extracted_content - def _move_processed_messages(self, to_be_moved: list): + def _move_processed_messages(self, to_be_moved: List[bytes]): 'Move processed messages to the designated `Trash`.' - message_set = b','.join(to_be_moved) + message_set = b','.join(to_be_moved).decode() self._conn.copy( message_set=message_set, new_mailbox=self._config['imap']['trash_path']) @@ -279,23 +327,21 @@ class ImapHandler(object): password=self._config['imap']['password']) response, result = self._conn.select( mailbox=self._config['imap']['folder_path']) - if response != 'OK': + if response != 'OK' or result[0] is None: raise ServerException() self._no_messages = int(result[0]) - response, self._all_messages = self._conn.search( - None, 'ALL') # type: Tuple[str, List[bytes]] + response, self._all_messages = self._conn.search(None, 'ALL') if response != 'OK': raise ServerException() to_be_moved = [] for num in self._all_messages[0].split(): # type: bytes extracted_content = self._parse_message(num=num) - if not extracted_content: + if not extracted_content or type(num) is not bytes: continue to_be_moved.append(num) for content_item in extracted_content: parser = XmlParser( - content=content_item, config=self._config, - sql_conn=self._sql_conn) + content=content_item, sql_conn=self._sql_conn) parser.process() if to_be_moved: self._move_processed_messages(to_be_moved=to_be_moved) @@ -359,8 +405,8 @@ class ReportSender(object): datetime_end = datetime.fromtimestamp(row.unixtime_end) text += _EMAIL_ITEMS.format( **row._asdict(), - datetime_start=datetime_start.strftime(format='%c'), - datetime_end=datetime_end.strftime(format='%c'), + datetime_start=datetime_start.strftime('%c'), + datetime_end=datetime_end.strftime('%c'), hostnames=self._get_hostnames(ip=row.offending_ip)) return text @@ -421,7 +467,7 @@ def main(): return _init_cfg(parsed_args=parsed_args) config = _get_loaded_cfg(parsed_args=parsed_args) if parsed_args.init_db: - return _init_db(parsed_args=parsed_args, config=config) + return _init_db(config=config) if parsed_args.report: report_sender = ReportSender(config=config) return report_sender.process()