#! /usr/bin/env python

# This file is part of IVRE.
# Copyright 2011 - 2026 Pierre LALET <pierre@droids-corp.org>
#
# IVRE is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# IVRE is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
# License for more details.
#
# You should have received a copy of the GNU General Public License
# along with IVRE. If not, see <http://www.gnu.org/licenses/>.


import argparse
import json
import os
import sys
from glob import glob
from typing import Dict

from ivre.config import RIR_PATH
from ivre.db import DBRir, db
from ivre.utils import CLI_ARGPARSER, range2nets, str2list


def printrec_full(rec: Dict[str, str]) -> None:
    """Print a record in a whois-like, multi-line text format.

    The ``_id``, ``source_file`` and ``source_hash`` fields are dropped
    first. A record holding both ``start`` and ``stop`` is introduced by
    an ``inetnum:`` line: the single network when ``range2nets()`` yields
    exactly one, ``start - stop`` otherwise. Any other record is
    introduced by ``aut-num: AS<number>``. The remaining fields follow,
    sorted by name; a value containing newlines is printed as ``name:``
    followed by one indented line per value line. A blank line closes the
    record.

    Note that ``rec`` is modified in place: the fields above are popped.
    """
    for internal in ("_id", "source_file", "source_hash"):
        rec.pop(internal, None)
    try:
        first, last = rec.pop("start"), rec.pop("stop")
    except KeyError:
        print(f"aut-num: AS{rec.pop('aut-num')}")
    else:
        networks = list(range2nets((first, last)))
        label = networks[0] if len(networks) == 1 else f"{first} - {last}"
        print(f"inetnum: {label}")
    for name, value in sorted(rec.items()):
        if "\n" not in value:
            print(f"{name}: {value}")
            continue
        print(f"{name}:")
        for part in value.split("\n"):
            print(f"    {part}")
    print()


def printrec_json(rec: Dict[str, str]) -> None:
    """Print a record as a single line of JSON.

    Only the ``_id`` field is dropped (``rec`` is modified in place);
    every other field is serialized as-is.
    """
    rec.pop("_id", None)
    serialized = json.dumps(rec)
    print(serialized)


def printrec_short(rec: Dict[str, str]) -> None:
    """Print a one-line summary of a record.

    A record holding both ``start`` and ``stop`` is summarized as
    ``<network>: <netname> - <country>``. ``<network>`` is the single
    network when ``range2nets()`` yields exactly one, ``start - stop``
    otherwise. Any other record is summarized as
    ``AS<aut-num>: <as-name> - <country>``. Newlines inside values are
    rendered as `` / ``. A missing ``netname``, ``as-name`` or ``country``
    field is shown as an empty string instead of aborting the listing.

    ``rec`` is not modified.

    Raises:
        KeyError: if ``rec`` has neither ``start``/``stop`` nor
            ``aut-num``.
    """
    try:
        start, stop = rec["start"], rec["stop"]
    except KeyError:
        obj = f"AS{rec['aut-num']}"
        name_field = "as-name"
    else:
        nets = list(range2nets((start, stop)))
        obj = nets[0] if len(nets) == 1 else f"{start} - {stop}"
        name_field = "netname"
    # Same flattening for the name and the country, tolerating absent
    # fields like the original already did for "country".
    info = " - ".join(
        rec.get(field, "").replace("\n", " / ") for field in (name_field, "country")
    )
    print(f"{obj}: {info}")


def _confirm_or_exit(prompt: str) -> None:
    """Ask for confirmation on an interactive terminal; exit(-1) if refused.

    When stdin is not a TTY (e.g. in scripts), no question is asked and
    the caller simply proceeds.
    """
    if os.isatty(sys.stdin.fileno()):
        sys.stdout.write(prompt)
        ans = input()
        if ans.lower() != "y":
            sys.exit(-1)


def main() -> None:
    """Look up, query or manage the RIR database from the command line.

    Depending on the arguments, this:

    - resets the data (``args.init``), creates indexes
      (``args.ensure_indexes``) or migrates the schema
      (``args.update_schema``), then exits;
    - downloads RIR files (``--download``) and/or imports them
      (``--insert``), then exits;
    - prints the distinct values of a field (``args.distinct``);
    - prints the records (or their count) matching a full-text search
      (``--search``, only offered when the backend has ``searchtext()``);
    - prints, for each IP address given, the best matching record.

    ``--country`` restricts the results of every query mode. Options not
    defined here (from-db, short, json, sort, limit, skip, count,
    distinct, ...) come from ``CLI_ARGPARSER``.
    """
    parser = argparse.ArgumentParser(
        description="Lookup & manage RIR databases.",
        parents=[CLI_ARGPARSER],
        conflict_handler="resolve",
    )
    if hasattr(db.rir, "searchtext"):  # FIXME, move to DBRir
        parser.add_argument(
            "--search", metavar="FREE TEXT", help="perform a full-text search"
        )
    parser.add_argument(
        "--country", metavar="CODE", help="show only results from this country"
    )
    parser.add_argument(
        "ips",
        nargs="*",
        help="Display results for specified IP addresses.",
    )
    parser.add_argument("--download", action="store_true")
    parser.add_argument("--insert", action="store_true")
    # inherited from CLI_ARGPARSER but meaningless here
    parser.add_argument("--to-db", help=argparse.SUPPRESS)
    parser.add_argument("--http-urls", help=argparse.SUPPRESS)
    parser.add_argument("--http-urls-names", help=argparse.SUPPRESS)
    parser.add_argument("--http-urls-full", help=argparse.SUPPRESS)
    parser.add_argument("--delete", help=argparse.SUPPRESS)
    args = parser.parse_args()
    if args.from_db:
        dbase = DBRir.from_url(args.from_db)
        dbase.globaldb = db
    else:
        dbase = db.rir
    # --search is only registered when the *default* backend (db.rir)
    # has searchtext(); with --from-db, dbase may be a different backend,
    # so the attribute may be missing from args.
    search = getattr(args, "search", None)
    use_search = search is not None and hasattr(dbase, "searchtext")
    if args.short:
        printrec = printrec_short
    elif args.json:
        printrec = printrec_json
    else:
        printrec = printrec_full
    if args.init:
        _confirm_or_exit(
            "This will remove all existing RIR data from your database. Proceed? [y/N] "
        )
        dbase.init()
        sys.exit(0)
    if args.ensure_indexes:
        _confirm_or_exit("This will lock your database. Process ? [y/N] ")
        dbase.ensure_indexes()
        sys.exit(0)
    if args.sort is None:
        sortkeys = [("start", 1), ("stop", -1)]
    else:
        # A leading "~" requests descending order for that field.
        sortkeys = [
            (field[1:], -1) if field.startswith("~") else (field, 1)
            for field in args.sort
        ]
    if args.update_schema:
        dbase.migrate_schema(args.version)
        sys.exit(0)
    kargs = {}
    if args.limit is not None:
        kargs["limit"] = args.limit
    if args.skip is not None:
        kargs["skip"] = args.skip
    if sortkeys:
        kargs["sort"] = sortkeys
    if args.download or args.insert:
        if args.download:
            filenames = dbase.fetch()
        if args.insert:
            if args.download:
                dbase.import_files(filenames)
            else:
                base_path = "." if RIR_PATH is None else RIR_PATH
                dbase.import_files(glob(os.path.join(base_path, "*.db*")))
        elif args.download:
            print("\n".join(sorted(filenames)))
        sys.exit(0)
    flt = dbase.flt_empty
    # Build the shared filters before any output mode so that --distinct
    # honors --country just like the other modes do.
    if args.country is not None:
        flt = dbase.flt_and(flt, dbase.searchcountry(str2list(args.country)))
    if use_search:
        flt = dbase.flt_and(flt, dbase.searchtext(search))
    if args.distinct is not None:
        if args.ips:
            # One searchhost() filter per address, OR-ed together.
            flt = dbase.flt_and(
                flt, dbase.flt_or(*(dbase.searchhost(addr) for addr in args.ips))
            )
        for val in dbase.distinct(args.distinct, flt=flt, **kargs):
            print(val)
        sys.exit(0)
    if use_search and not args.ips:
        if args.count:
            print(f"{search}: {dbase.count(flt)}")
            sys.exit(0)
        if not args.json:
            print(search)
            print()
        for res in dbase.get(flt, **kargs):
            printrec(res)
        if not args.json:
            print()
        sys.exit(0)
    if not args.ips and args.count:
        print(dbase.count(flt))
        sys.exit(0)
    # For IP addresses, we only output the "best" (smallest) match, so
    # no limit, skip or sort
    for addr in args.ips:
        if args.count:
            print(f"{addr}: {dbase.count(dbase.flt_and(dbase.searchhost(addr), flt))}")
            continue
        if not args.json:
            print(addr)
        res = dbase.get_best(addr, spec=flt)
        if res is None:
            if not args.json:
                print("UNKNOWN")
        else:
            printrec(res)
        if not (args.short or args.json):
            print()
