#! /usr/bin/python3

# Copyright 2014..2022, Martin <debacle@debian.org>
# License: AGPL-3+

# Python standard modules
import argparse
import asyncio
import collections
import configparser
import email.mime.text
import email.utils
import hashlib
import html
import os
import smtplib
import socket
import subprocess
import sys
import textwrap

# additional modules
import apt
import prettytable
import slixmpp

# Program identity: display name (used as X-Mailer and chat nick), short name
# (used in the default file paths) and release number.
longname = "Pain in the APT"
shortname = "painintheapt"
version = "0.20220226"

# Table headers; their lower-cased forms are the field names of Package.
columns = ["Name", "Installed", "Candidate"]
Package = collections.namedtuple("Package", [header.lower() for header in columns])


def getargs():
    """Parse the command line.

    Returns:
        The ``argparse.Namespace`` with the attributes ``configfile``,
        ``debug``, ``force``, ``stampfile`` and ``testmessage``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Pester people about available package updates by email or jabber.",
    )
    # Options are declared as data, in the order they appear in --help.
    options = (
        (
            ("-c", "--configfile"),
            {"default": "/etc/%s.conf" % shortname, "help": "configuration file"},
        ),
        (
            ("-d", "--debug"),
            {
                "default": False,
                "action": "store_true",
                "help": "print debug output to stderr",
            },
        ),
        (
            ("-f", "--force"),
            {
                "default": False,
                "action": "store_true",
                "help": "send message, even if updates did not change",
            },
        ),
        (
            ("-s", "--stampfile"),
            {"help": "stamp file", "default": "/var/lib/%s/stamp" % shortname},
        ),
        (
            ("-t", "--testmessage"),
            {
                "default": False,
                "action": "store_true",
                "help": "send a test message only",
            },
        ),
        (
            ("-v", "--version"),
            {"action": "version", "version": "%(prog)s " + version},
        ),
    )
    for flags, settings in options:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()


def update():
    """Refresh the APT cache and collect the pending package changes.

    The cache is updated, reopened and marked for a dist-upgrade; every
    package in the resulting change set is reported.

    Returns:
        A tuple ``(cache, updates)`` where ``updates`` is a list of
        ``Package`` tuples. A missing installed or candidate version is
        shown as ``"-"``.
    """
    cache = apt.Cache()
    cache.update()
    cache.open()
    cache.upgrade(dist_upgrade=True)

    def describe(change):
        # Look the package up by its plain name to read both versions.
        pkg_name = change._pkg.name
        entry = cache[pkg_name]
        if entry.installed:
            current = entry.installed.version
        else:
            current = "-"
        if entry.candidate:
            target = entry.candidate.version
        else:
            target = "-"
        return Package(pkg_name, current, target)

    return cache, [describe(change) for change in cache.get_changes()]


def wrap(text, maxwid):
    """Re-flow ``text`` into lines of at most ``maxwid`` characters.

    Returns the wrapped lines joined by newlines (empty string for empty
    input).
    """
    wrapped_lines = textwrap.wrap(text, width=maxwid)
    return "\n".join(wrapped_lines)


# Memoised result of get_changelogs(); None means "not computed yet".
_changes = None


def get_changelogs(cache, send_changes):
    """Download and format the changelogs of all pending changes.

    Beware: This is very slow, so the formatted text is computed once and
    then reused for every later call.

    Identical changelogs for different binary packages are combined under
    one heading that lists all of their names.

    Args:
        cache: The APT cache holding the pending changes, or None.
        send_changes: Changelogs are only produced when this is exactly True.

    Returns:
        The formatted changelog text, or "" when ``cache`` is None or
        ``send_changes`` is not True.
    """
    global _changes
    if cache is None or send_changes is not True:
        return ""
    # Compare against None rather than truthiness, so that an empty result
    # is memoised as well and the cache is not queried again.
    if _changes is not None:
        return _changes
    changelogs = collections.defaultdict(list)
    for c in cache.get_changes():
        name = c._pkg.name
        changelog = cache[name].get_changelog().strip()
        changelogs[changelog].append(name)
    # One section per distinct changelog: wrapped package names as heading,
    # sections sorted and separated by a horizontal rule.
    maxwid = 79
    separator = "\n" + "-" * maxwid + "\n"
    _changes = separator.join(
        sorted(
            wrap(", ".join(sorted(names)), maxwid) + ":\n\n" + changelog
            for changelog, names in changelogs.items()
        )
    )
    return _changes


def maketable(lst):
    """Render the package updates as a text table.

    Args:
        lst: Iterable of objects with ``name``, ``installed`` and
            ``candidate`` attributes.

    Returns:
        The table as a string: left-aligned, sorted by the first column,
        with each cell wrapped to 23 characters.
    """
    cell_width = 23
    table = prettytable.PrettyTable(columns)
    table.sortby = columns[0]
    table.align = "l"
    for pkg in lst:
        cells = (pkg.name, pkg.installed, pkg.candidate)
        table.add_row([wrap(cell, cell_width) for cell in cells])
    return table.get_string()


class JabberBot(slixmpp.ClientXMPP):
    """XMPP client that delivers one update report and then disconnects.

    When the session starts, the report is sent as a chat message to every
    address in ``to``, as a groupchat message to ``room`` (if set) and as an
    Atom entry to the pubsub node (if both service and node are set).
    """

    def __init__(
        self, jid, password, to, room, pubsub_service, pubsub_node, nick, subject, table, changes
    ):
        super().__init__(jid, password)
        # where to deliver
        self.to = to
        self.room = room
        self.pubsub_service = pubsub_service
        self.pubsub_node = pubsub_node
        self.nick = nick
        # what to deliver
        self.subject = subject
        self.table = table
        self.changes = changes
        self.add_event_handler("session_start", self.start)

    def start(self, event):
        """Handle ``session_start``: send all messages, then disconnect."""
        self.get_roster()
        self.send_presence()
        fence = "```"

        def plain_body():
            # Subject is repeated in the body: not every client shows a
            # message subject, and groupchat messages carry none at all.
            return "\n".join([self.subject, fence, self.table, fence, "\n", self.changes])

        for recipient in self.to:
            self.send_message(
                mto=recipient,
                msubject=self.subject,
                mbody=plain_body(),
                mtype="chat",
            )

        if self.room:
            self.plugin["xep_0045"].join_muc(self.room, self.nick)
            self.send_message(mto=self.room, mbody=plain_body(), mtype="groupchat")

        if self.pubsub_service and self.pubsub_node:
            # Atom entry: table as preformatted text, each changelog line as
            # its own paragraph with spaces turned into non-breaking spaces.
            pieces = [
                '<entry xmlns="http://www.w3.org/2005/Atom"><title>',
                html.escape(self.subject),
                '</title><content type="xhtml"><div>',
                '<pre xmlns="http://www.w3.org/1999/xhtml">',
                html.escape(self.table),
                "</pre><p>",
                html.escape(self.changes).replace("\n", "</p>\n<p>").replace(" ", "&#160;"),
                "</p></div></content></entry>",
            ]
            entry = slixmpp.xmlstream.ET.fromstring("".join(pieces))
            self["xep_0060"].publish(self.pubsub_service, self.pubsub_node, payload=entry)

        self.disconnect(wait=True)


def read_password(config, config_dir):
    """Return the password configured for one transport section.

    ``password_file`` takes precedence; it is resolved relative to
    ``config_dir`` and its content is returned with surrounding whitespace
    removed. Otherwise the deprecated inline ``password`` option is used.

    Args:
        config: Mapping-like section offering ``get(key, default)``.
        config_dir: Directory that relative ``password_file`` paths refer to.

    Returns:
        The password, or "" when none is configured.

    Raises:
        OSError: If the password file cannot be read.
    """
    password_file = config.get("password_file", "").strip()
    if password_file:
        with open(os.path.join(config_dir, password_file)) as f:
            return f.read().strip()

    password = config.get("password", "")
    # Only warn when the deprecated option is really in use; a section
    # without any password (e.g. unauthenticated SMTP) is not deprecated.
    if password:
        print("password deprecated, use password_file instead", file=sys.stderr)
    return password


def sendxmpp(config, config_dir, table, count, host, debug, changes):
    """Send the report by XMPP: to single contacts, a room and/or pubsub.

    Args:
        config: The XMPP configuration section.
        config_dir: Directory that a relative ``password_file`` refers to.
        table: Text table of the pending updates.
        count: Number of pending updates (used in the subject).
        host: Host name shown in the subject.
        debug: Unused here; present so all senders share one signature.
        changes: Formatted changelog text, may be empty.
    """
    jid = config.get("jid", "")
    password = read_password(config, config_dir)
    # "".split(",") yields [""], which would send a message to an empty JID;
    # drop empty entries and the blanks around "a@example.org, b@example.org".
    to = [addr.strip() for addr in config.get("to", "").split(",") if addr.strip()]
    room = config.get("room")
    pubsub_service = config.get("pubsub_service", "").strip()
    pubsub_node = config.get("pubsub_node", "").strip()
    subject = "%d package update(s) for %s" % (count, host)
    xmpp = JabberBot(
        jid,
        password,
        to,
        room,
        pubsub_service,
        pubsub_node,
        longname,
        subject,
        table,
        changes,
    )
    xmpp.register_plugin("xep_0030")  # service discovery
    if room:
        xmpp.register_plugin("xep_0045")  # multi-user chat
    if pubsub_service and pubsub_node:
        xmpp.register_plugin("xep_0060")  # pubsub
    xmpp.register_plugin("xep_0199")  # XMPP ping

    xmpp.connect()

    # The bot disconnects itself after sending; run the loop until then.
    xmpp.loop.run_until_complete(xmpp.disconnected)

    # Cancel whatever is still scheduled on the client's loop.
    for task in asyncio.all_tasks(loop=xmpp.loop):
        task.cancel()


def sendsmtp(config, config_dir, table, count, host, debug, changes):
    """Send email by SMTP to whomsoever it may concern.

    Args:
        config: The SMTP configuration section.
        config_dir: Directory that a relative ``password_file`` refers to.
        table: Text table of the pending updates.
        count: Number of pending updates (used in the subject).
        host: Host name shown in the subject.
        debug: Enable smtplib's protocol debug output.
        changes: Formatted changelog text, may be empty.

    Raises:
        smtplib.SMTPException, OSError: On connection or delivery problems.
    """
    server = config.get("server", "localhost")
    port = config.getint("port", 25)
    username = config.get("username", "")
    password = read_password(config, config_dir)
    from_ = config.get("from", username)
    to = config.get("to", username)
    cc = config.get("cc", "")

    msg = email.mime.text.MIMEText("\n\n".join([table, changes]).strip(), "plain", "utf-8")
    msg["From"] = from_
    msg["To"] = to
    msg["Subject"] = "%d package update(s) for %s" % (count, host)
    msg["X-Mailer"] = longname

    if cc:
        msg["Cc"] = cc

    # Parse only the non-empty header values: a dangling comma left by an
    # empty Cc can make strict address parsing reject the whole list.
    # dict.fromkeys() removes duplicates while keeping a stable order.
    headers = [value for value in (to, cc) if value]
    recipients = list(
        dict.fromkeys(addr for _, addr in email.utils.getaddresses(headers) if addr)
    )

    # The context manager closes the connection even if a step below fails.
    with smtplib.SMTP(host=server, port=port) as s:
        if debug:
            s.set_debuglevel(True)
        s.starttls()
        s.ehlo_or_helo_if_needed()
        if username or password:
            s.login(username, password)
        s.sendmail(from_, recipients, msg.as_string())


def sendmailx(config, config_dir, table, count, host, debug, changes):
    """Send email by mailx to whomsoever it may concern.

    Args:
        config: The MAILX configuration section.
        config_dir: Unused here; present so all senders share one signature.
        table: Text table of the pending updates.
        count: Number of pending updates (used in the subject).
        host: Host name shown in the subject.
        debug: Unused here; present so all senders share one signature.
        changes: Formatted changelog text, may be empty.

    Raises:
        OSError: If mailx cannot be started.
        subprocess.CalledProcessError: If mailx exits with an error status.
    """
    cmd = [
        "/usr/bin/mailx",
        "-r",
        config.get("from", "root"),
        "-s",
        "%d package update(s) for %s" % (count, host),
        "-a",
        "X-Mailer: " + longname,
    ]
    cc = config.get("cc", "")
    if cc:
        cmd += ["-c", cc]
    # this is taken from apticron
    if os.path.realpath("/usr/bin/mailx") == "/usr/bin/heirloom-mailx":
        cmd += ["-S", "ttycharset=utf-8"]
    else:
        cmd += [
            "-a",
            "MIME-Version: 1.0",
            "-a",
            "Content-type: text/plain; charset=UTF-8",
            "-a",
            "Content-transfer-encoding: 8bit",
        ]
    to = config.get("to", "root")
    body = "\n\n".join([table, changes]).strip()
    # The pipe to mailx is binary, so the text must be encoded; UTF-8 matches
    # the charset announced on the command line. check=True turns a failing
    # mailx into an exception instead of silently losing the message.
    subprocess.run(cmd + [to], input=body.encode("utf-8"), check=True)


def has_changed(configfile, table, stampfile):
    """Tell whether configuration or update table differ from the last run.

    A SHA-1 digest over the content of ``configfile`` followed by ``table``
    is compared with the first line of ``stampfile``.

    Args:
        configfile: Path of the configuration file.
        table: Text table of the pending updates.
        stampfile: Path of the file holding the digest of the previous run.

    Returns:
        A tuple ``(change, newhash)``: ``change`` is True when the digests
        differ or the stamp file is unreadable, ``newhash`` is the new hex
        digest.

    Raises:
        OSError: If ``configfile`` cannot be read.
    """
    hashsum = hashlib.sha1()
    # Close the config file deterministically instead of leaking the handle.
    with open(configfile) as f:
        for line in f:
            hashsum.update(line.encode("utf-8"))
    hashsum.update(table.encode("utf-8"))
    newhash = hashsum.hexdigest()
    try:
        with open(stampfile) as f:
            oldhash = f.readline().strip()
    except (OSError, ValueError):
        # Missing, unreadable or undecodable stamp: treat as "changed".
        oldhash = "invalid"
    return oldhash != newhash, newhash


class AcquireProgress(apt.progress.text.AcquireProgress):
    """Text acquire progress that is only visible in debug mode.

    Progress output goes to stderr when ``debug`` is true and is discarded
    via /dev/null otherwise.
    """

    def __init__(self, debug):
        if debug:
            sink = sys.stderr
        else:
            sink = open("/dev/null", "w")
        super().__init__(outfile=sink)


if __name__ == "__main__":
    args = getargs()
    config = configparser.ConfigParser()
    # ConfigParser.read() skips files it cannot open and returns the list of
    # files it actually parsed. Stop early with a clear message rather than
    # running the slow APT update without any usable configuration.
    if not config.read(args.configfile):
        print("cannot read configuration file %s" % args.configfile, file=sys.stderr)
        sys.exit(1)
    config_dir = os.path.dirname(args.configfile)

    fqdn = socket.getfqdn()
    # workaround for dodgy /etc/hosts
    if fqdn in ["localhost", "localhost.localdomain"]:
        fqdn = socket.gethostname() or fqdn

    if args.testmessage:
        # Test mode: no APT access, fixed text, always send.
        cache = None
        count = 0
        table = "this is a test message from painintheapt"
        change = True
    else:
        cache, updates = update()
        count = len(updates)
        table = maketable(updates) if count else ""
        change, newhash = has_changed(args.configfile, table, args.stampfile)

    # Try every configured transport; a failing one must not stop the others,
    # it only turns the exit status into 1.
    ret = 0
    for section, function in [
        ("XMPP", sendxmpp),
        ("SMTP", sendsmtp),
        ("MAILX", sendmailx),
    ]:
        try:
            if section in config.sections() and (change or args.force):
                send_changes = config[section].getboolean("send_changes", True)
                function(
                    config[section],
                    config_dir,
                    table,
                    count,
                    fqdn,
                    args.debug,
                    get_changelogs(cache, send_changes),
                )
        except Exception as err:
            print(str(err), file=sys.stderr)
            ret = 1

    # Test mode ends here: no stamp is written, nothing is downloaded.
    if args.testmessage:
        sys.exit(ret)

    # Remember what was reported, so an unchanged state is not sent again.
    if change or args.force:
        with open(args.stampfile, "wb") as f:
            f.write(newhash.encode("utf-8"))

    # Pre-download the packages of the pending changes.
    cache.fetch_archives(progress=AcquireProgress(args.debug))

    sys.exit(ret)
