#!/usr/bin/env python3

#
# asn-banhammer.py - Firewalld ASN ban utility
#
# Copyright (C) 2025 ViciDial Group
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#


import sys
import os
import subprocess
import tempfile
import ipaddress
import xml.etree.ElementTree as ET
import argparse

def fetch_asn_prefixes(asn, timeout=60):
    """Fetch IPv4 route prefixes registered for *asn* in RADb.

    Runs ``whois -h whois.radb.net -i origin <asn>`` and parses the RPSL
    objects in the reply. Only the ``route:`` attribute is treated as a
    prefix; ``route6:`` values and free-text attributes that merely contain
    a slash are ignored. Each prefix is paired with the first ``descr:`` of
    its *own* object. In RPSL the ``route:`` line precedes ``descr:``, so
    pairing must wait until the whole object has been read.

    Args:
        asn: ASN string as accepted by RADb, e.g. ``"AS132203"``.
        timeout: Seconds to wait for the whois query before giving up.

    Returns:
        List of ``(cidr_string, description_or_None)`` tuples in reply order.
        Host bits are cleared (non-strict parsing). The list is empty when
        no route objects are found.

    Exits:
        Calls ``sys.exit(1)`` if whois is missing, fails, or times out.
    """
    try:
        result = subprocess.run(
            ["whois", "-h", "whois.radb.net", "-i", "origin", asn],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True, text=True,
            timeout=timeout,
        )
    except FileNotFoundError:
        print("[!] The 'whois' command was not found; please install it.")
        sys.exit(1)
    except subprocess.TimeoutExpired:
        print(f"[!] Whois query timed out after {timeout} seconds.")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"[!] Whois query failed: {e.stderr}")
        sys.exit(1)

    # Split the reply into RPSL objects: one dict of attribute -> first value.
    objects = []
    current = {}
    for line in result.stdout.splitlines():
        if not line.strip():
            # A blank line terminates the current object.
            if current:
                objects.append(current)
            current = {}
            continue
        if line[0] in " \t+%#":
            # Continuation lines and server/comment lines add no new attribute.
            continue
        key, sep, value = line.partition(":")
        if not sep:
            continue
        key = key.strip().lower()
        if key in ("route", "route6") and current:
            # The first attribute of an RPSL object is its class, so a new
            # route/route6 line starts a new object even without a blank line.
            objects.append(current)
            current = {}
        current.setdefault(key, value.strip())
    if current:
        objects.append(current)

    prefix_descr = []
    for obj in objects:
        # Take only the first token so trailing "# comment" text is ignored.
        tokens = obj.get("route", "").split()
        if not tokens:
            continue
        try:
            net = ipaddress.IPv4Network(tokens[0], strict=False)
        except ValueError:
            continue
        prefix_descr.append((str(net), obj.get("descr") or None))
    return prefix_descr

def consolidate_prefixes(prefixes):
    """Collapse overlapping and adjacent networks into the minimal covering set.

    Args:
        prefixes: Iterable of ``ipaddress`` network objects and/or CIDR
            strings. Strings are parsed non-strictly, so host bits are
            cleared. All entries must share one IP version.

    Returns:
        List of network objects produced by ``ipaddress.collapse_addresses``.

    Raises:
        ValueError: a string entry is not a valid network.
        TypeError: IPv4 and IPv6 entries are mixed.
    """
    networks = [
        ipaddress.ip_network(p, strict=False) if isinstance(p, str) else p
        for p in prefixes
    ]
    return list(ipaddress.collapse_addresses(networks))

def _xml_comment_text(text):
    """Return *text* made safe for the body of an XML ``<!-- ... -->`` comment.

    XML forbids ``--`` inside comments. Whois ``descr:`` values are free
    text, so runs of hyphens are split with spaces.
    """
    text = str(text)
    while "--" in text:
        text = text.replace("--", "- -")
    return text


def write_ipset_xml(asn, prefixes, ipset_dir="/etc/firewalld/ipsets"):
    """Write the firewalld ipset XML file for *asn*.

    The file is ``<asn lowercased>.xml`` inside *ipset_dir*. The ipset type
    and the optional ``<short>``, ``<description>`` and ``<option>`` elements
    come from the module-level ``args`` namespace that ``main()`` sets. When
    that global is absent, the type defaults to ``hash:net`` and the optional
    elements are omitted.

    All element text and attribute values are XML-escaped, and each prefix
    description is emitted as a trailing comment after its ``<entry>``.
    Content goes to a temporary file in the same directory, which is then
    renamed over the target. An interrupted write therefore never leaves a
    truncated XML file behind.

    Args:
        asn: ASN string, e.g. ``"AS132203"``; lowercased for the file name.
        prefixes: Iterable of ``(cidr, description_or_None)`` pairs. Pairs
            with a falsy cidr are skipped.
        ipset_dir: Directory to write into.

    Raises:
        ValueError: the lowercased *asn* is empty, contains a path separator,
            or starts with ``.``.
        OSError: the file could not be written.
    """
    from xml.sax.saxutils import escape, quoteattr

    name = asn.lower()
    if not name or os.sep in name or name.startswith("."):
        raise ValueError(f"unsafe ipset name derived from ASN: {asn!r}")
    ipset_path = os.path.join(ipset_dir, f"{name}.xml")

    # main() publishes its argparse namespace as the global ``args``; tolerate
    # it being missing so this function also works on its own.
    cli_args = globals().get("args")
    ipset_type = getattr(cli_args, "ipset_type", None) or "hash:net"
    short = getattr(cli_args, "short", None)
    desc = getattr(cli_args, "description", None)
    option = getattr(cli_args, "option", None)

    lines = [
        '<?xml version="1.0" encoding="utf-8"?>',
        f"<ipset type={quoteattr(ipset_type)}>",
    ]
    if short:
        lines.append(f"  <short>{escape(short)}</short>")
    if desc:
        lines.append(f"  <description>{escape(desc)}</description>")
    if option:
        opt_name, opt_value = option
        lines.append(f"  <option name={quoteattr(opt_name)} value={quoteattr(opt_value)}/>")
    written = 0
    for net, descr in prefixes:
        if not net:
            continue
        entry = f"  <entry>{escape(str(net))}</entry>"
        if descr:
            entry += f" <!-- {_xml_comment_text(descr)} -->"
        lines.append(entry)
        written += 1
    lines.append("</ipset>")
    content = "\n".join(lines) + "\n"

    # The ".tmp" suffix keeps the partial file from looking like an ipset (*.xml).
    fd, tmp_path = tempfile.mkstemp(dir=ipset_dir, prefix=".", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(content)
        # mkstemp creates the file 0600; apply the mode a plain open() would use.
        umask = os.umask(0)
        os.umask(umask)
        os.chmod(tmp_path, 0o666 & ~umask)
        os.replace(tmp_path, ipset_path)
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    print(f"[*] Wrote {written} prefixes to {ipset_path}")

def _normalize_asn(raw):
    """Return *raw* as a canonical ``AS<number>`` string, or None if invalid.

    Accepts ``AS132203``, ``as132203`` or a bare ``132203``. Anything else is
    rejected, including text with path separators, which would otherwise
    flow into the ipset file path.
    """
    import re

    match = re.fullmatch(r"(?:AS)?([0-9]+)", raw.strip(), re.IGNORECASE)
    if match is None:
        return None
    number = int(match.group(1))
    # ASNs are 32-bit unsigned integers (RFC 6793).
    if number > 0xFFFFFFFF:
        return None
    return f"AS{number}"


def _consolidate_with_descriptions(prefix_descr):
    """Collapse ``(cidr, descr)`` pairs, keeping one description per result.

    Each consolidated network gets the first non-empty description, in input
    order, of any original prefix it covers, or None if there is none.
    """
    import bisect

    originals = [(ipaddress.IPv4Network(net), descr) for net, descr in prefix_descr]
    consolidated = sorted(consolidate_prefixes([net for net, _ in originals]))
    # Collapsed networks are disjoint, so once sorted, the network covering an
    # original prefix is the last one whose start address is <= the prefix's.
    starts = [int(net.network_address) for net in consolidated]
    descriptions = [None] * len(consolidated)
    for net, descr in originals:
        if not descr:
            continue
        idx = bisect.bisect_right(starts, int(net.network_address)) - 1
        if descriptions[idx] is None:
            descriptions[idx] = descr
    return [(str(net), d) for net, d in zip(consolidated, descriptions)]


def _print_summary(asn, zone, consolidated_descr, sample_size=10):
    """Print the ASN, zone, prefix count and the first *sample_size* prefixes."""
    print(f"ASN: {asn}")
    print(f"Zone: {zone}")
    print(f"Number of prefixes: {len(consolidated_descr)}")
    print("Sample prefixes:")
    for net, descr in consolidated_descr[:sample_size]:
        if descr:
            print(f"  {net}  # {descr}")
        else:
            print(f"  {net}")
    if len(consolidated_descr) > sample_size:
        print(f"  ... and {len(consolidated_descr)-sample_size} more ...")


def _run_firewall_cmd(*cmd_args):
    """Run ``firewall-cmd`` with *cmd_args*; print a message and exit 1 on failure."""
    try:
        subprocess.run(["firewall-cmd", *cmd_args], check=True)
    except FileNotFoundError:
        print("[!] firewall-cmd not found; is firewalld installed?")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"[!] firewall-cmd {' '.join(cmd_args)} failed with exit code {e.returncode}")
        sys.exit(1)


def main():
    """Command-line entry point: ban every IPv4 prefix registered for an ASN.

    Validates the arguments and requires root. It then fetches the ASN's
    route objects from RADb, consolidates them, and shows a summary for
    confirmation. Finally it writes the firewalld ipset file, reloads
    firewalld, and permanently adds the ipset as a source of the chosen zone.
    """
    parser = argparse.ArgumentParser(description="Generate firewalld ipset from ASN prefixes.")
    parser.add_argument("ASN", nargs="?", help="Autonomous System Number (e.g. AS132203)")
    parser.add_argument("--zone", default="drop", help="Firewalld zone to add the ipset to (default: drop)")
    parser.add_argument("--short", help="Short name for the ipset (optional)")
    parser.add_argument("--description", help="Description for the ipset (optional)")
    parser.add_argument("--ipset-type", default="hash:net", help="Type for the ipset (default: hash:net)")
    parser.add_argument("--option", nargs=2, metavar=("NAME", "VALUE"), help="Add an <option> element to the ipset (e.g. --option maxelem 262144)")
    # write_ipset_xml() reads the parsed options through this module global.
    global args
    args = parser.parse_args()

    # ASN is optional to argparse, so enforce it here; it also becomes part of
    # a file path and must be strictly validated.
    if not args.ASN:
        parser.error("an ASN is required (e.g. AS132203)")
    asn = _normalize_asn(args.ASN)
    if asn is None:
        parser.error(f"invalid ASN: {args.ASN!r} (expected e.g. AS132203)")

    if os.geteuid() != 0:
        print("This script must be run as root.")
        sys.exit(1)

    zone = args.zone
    print(f"[*] Fetching prefixes for {asn} ...")
    prefix_descr = fetch_asn_prefixes(asn)
    if not prefix_descr:
        print(f"[!] No prefixes found for {asn}. Exiting.")
        sys.exit(1)
    print(f"[*] Consolidating {len(prefix_descr)} prefixes ...")
    consolidated_descr = _consolidate_with_descriptions(prefix_descr)
    print(f"[*] {len(consolidated_descr)} consolidated prefixes.")
    _print_summary(asn, zone, consolidated_descr)

    try:
        confirm = input("Proceed with writing ipset and updating firewalld? [y/N]: ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # No usable answer (closed stdin or Ctrl-C): treat it as "no".
        print()
        confirm = ""
    if confirm != 'y':
        print("Aborted by user.")
        sys.exit(0)

    try:
        write_ipset_xml(asn, consolidated_descr)
    except OSError as e:
        print(f"[!] Could not write ipset file: {e}")
        sys.exit(1)

    print("[*] Reloading firewalld ...")
    # Presumably needed so firewalld loads the new ipset file before the
    # permanent --add-source below references it.
    _run_firewall_cmd("--reload")
    ipset_name = asn.lower()
    print(f"[*] Adding ipset {ipset_name} to zone {zone} ...")
    _run_firewall_cmd("--permanent", f"--zone={zone}", f"--add-source=ipset:{ipset_name}")
    _run_firewall_cmd("--reload")
    print("[*] Done. Verify with:")
    print(f"    firewall-cmd --info-ipset={ipset_name}")
    print(f"    firewall-cmd --zone={zone} --list-sources")

# Script entry point: run the CLI only when executed directly, not on import.
# main() returns None, so the process exit status is 0 unless main() exits earlier.
if __name__ == "__main__":
    sys.exit(main())