#!/usr/bin/python3
"""
Runs a squid instance to cache TKLBAM's duplicity archives
Arguments:

    -h | --help         Show this help and exit
    <http_port>         [hostname:]port (e.g., 3128, 127.0.0.1:33128)
    <cache_size>        Size of squid cache (absolute or relative)

                        1000MB - 1000 MB (plus 16 MB for cache overhead)
                        2GB - 2048 MB

                        50% - 50% of free space in the cache dir's filesystem

Environment variables:

    TKLBAM_SQUID_CACHE_DIR      location of cache dir (default: {0})
    TKLBAM_SQUID_USER           user under which we execute squid (default: {1})
    TKLBAM_SQUID_BIN            path to tklbam squid binary (default: {2})

"""

import os
import pwd
import re
import string
import subprocess
import sys
from tempfile import NamedTemporaryFile
from typing import NoReturn

# Root of TKLBAM's bundled dependencies; the custom squid build lives below it.
TKLBAM_DEPS_DIR = "/usr/lib/tklbam/deps"

# Defaults; each one can be overridden by the environment variable of the
# same name (see main()).
TKLBAM_SQUID_CACHE_DIR = "/var/spool/tklbam-squid"
TKLBAM_SQUID_USER = "proxy"
TKLBAM_SQUID_BIN = TKLBAM_DEPS_DIR + "/usr/sbin/squid"

# squid.conf template rendered with string.Template. Placeholders:
#   $CACHE_DIR            cache directory path
#   $CACHE_SIZE           cache size in MB
#   $MAXIMUM_OBJECT_SIZE  largest cacheable object in MB
#   $HTTP_PORT            [hostname:]port squid listens on
# A literal '$' (the regex end anchor in refresh_pattern) is written as '$$'
# so string.Template emits a single '$'.
TKLBAM_CONF_TPL = r"""
# autogenerated file, do not edit

cache_dir ufs $CACHE_DIR $CACHE_SIZE 16 256
maximum_object_size $MAXIMUM_OBJECT_SIZE MB

# static

http_port $HTTP_PORT
icp_port 0

acl all src all
acl localhost src 127.0.0.1/32

http_access allow localhost
http_access deny all

cache_mem 0 MB
maximum_object_size_in_memory 0 KB

refresh_pattern -i duplicity.*\.gpg$$ 86400 90% 604800 ignore-auth

cache_log /dev/null
access_log none
cache_store_log none
pid_filename none
netdb_filename none
"""


class Error(Exception):
    """Raised for invalid cache-size arguments and failed squid invocations."""


def get_freespace(path: str) -> int:
    """Return the free space, in bytes, usable on the filesystem holding *path*.

    Counts f_bavail (blocks available to unprivileged users) rather than
    f_bfree, which also includes blocks reserved for root. The cache dir is
    handed to a non-root user by default, so reserved blocks are not usable.
    Block counts are in units of f_frsize per POSIX; f_bsize is only the
    preferred I/O size.
    """
    stats = os.statvfs(path)
    return stats.f_bavail * stats.f_frsize


def parse_cache_size(sizestr: str, cache_dir: str) -> int:
    """Convert a cache size string into a number of megabytes.

    Accepted forms (suffixes are case-insensitive):

    100     100 megabytes (taken as-is)
    100M    100 megabytes + 16 MB headroom for cache structure overhead
    100MB   same as 100M
    10G     10 * 1024 megabytes
    10GB    same as 10G
    50%     50% of the free space on the filesystem holding cache_dir

    Raises ValueError when the numeric part of a suffixed form is not an
    integer, and Error for an out-of-range percentage or an unknown format.
    """
    try:
        return int(sizestr)
    except ValueError:
        pass  # not a bare number; fall through to the suffixed forms

    megabytes = re.match(r"^(.*)MB?$", sizestr, re.IGNORECASE)
    if megabytes is not None:
        # pad by 16 MB to make room for the cache structure overhead
        return int(megabytes.group(1)) + 16

    gigabytes = re.match(r"^(.*)GB?$", sizestr, re.IGNORECASE)
    if gigabytes is not None:
        return 1024 * int(gigabytes.group(1))

    relative = re.match(r"^(.*)%$", sizestr, re.IGNORECASE)
    if relative is None:
        raise Error(f"illegal cache_size value '{sizestr}'")

    percent = int(relative.group(1))
    if not 0 <= percent <= 100:
        raise Error(
            f"bad percent {percent}% - should be between 0 and 100"
        )

    free_megabytes = get_freespace(cache_dir) * percent / 100.0 / (1024 * 1024)
    return int(free_megabytes)


def user_exists(user: str) -> bool:
    """Return True if *user* has an entry in the password database."""
    try:
        pwd.getpwnam(user)
    except KeyError:
        # getpwnam raises KeyError for unknown user names
        return False
    return True


def is_other_user(user: str) -> bool:
    """Return True when *user*'s uid differs from the uid running this process.

    Propagates KeyError from pwd.getpwnam if *user* does not exist.
    """
    target_uid = pwd.getpwnam(user).pw_uid
    return target_uid != os.getuid()


def usage(e: str | None = None) -> NoReturn:
    """Print an optional error, the syntax line and the help text, then exit 1.

    The help text is the module docstring with its {0}/{1}/{2} placeholders
    filled by the default cache dir, squid user and squid binary.
    Output goes to stderr.
    """
    if e:
        print(f"error: {e}", file=sys.stderr)
    print(f"Syntax: {sys.argv[0]} <http_port> <cache_size>", file=sys.stderr)
    # __doc__ is None when Python runs with -OO; avoid crashing on .strip()
    help_text = (__doc__ or "").strip().format(
        TKLBAM_SQUID_CACHE_DIR, TKLBAM_SQUID_USER, TKLBAM_SQUID_BIN
    )
    print(help_text, file=sys.stderr)
    sys.exit(1)


def fatal(e: str) -> NoReturn:
    """Print *e* as an error message to stderr and exit with status 1."""
    print(f"error: {e}", file=sys.stderr)
    sys.exit(1)


def _prepare_cache_dir(cache_dir: str, user: str) -> None:
    """Create cache_dir if needed, set mode 0750 and chown it to *user*.

    The chown only happens when *user* differs from the user running this
    process. Exits via usage() if cache_dir exists but is not a directory.
    """
    if os.path.exists(cache_dir) and not os.path.isdir(cache_dir):
        usage(f"'{cache_dir}' exists but is not a directory")
    else:
        os.makedirs(cache_dir, exist_ok=True)

    os.chmod(cache_dir, 0o750)
    if is_other_user(user):
        pw = pwd.getpwnam(user)
        os.chown(cache_dir, pw.pw_uid, pw.pw_gid)


def _render_conf(http_port: str, cache_dir: str, cache_size: int) -> str:
    """Return the squid.conf text built from TKLBAM_CONF_TPL.

    maximum_object_size normally equals cache_size (MB). On 32-bit Python
    (sys.maxsize <= 2**32) it is capped at 2000 MB, presumably to stay
    within a 32-bit size limit.
    """
    if cache_size > 2000 and not sys.maxsize > 2**32:
        print(
            "Detected 32-bit system, limiting maximum_object_size to 2000 MB",
            file=sys.stderr,
        )
        maximum_object_size = 2000
    else:
        maximum_object_size = cache_size

    return string.Template(TKLBAM_CONF_TPL).substitute(
        HTTP_PORT=http_port,
        CACHE_DIR=cache_dir,
        CACHE_SIZE=cache_size,
        MAXIMUM_OBJECT_SIZE=maximum_object_size,
    )


def _run_squid(squid_bin: str, conf: str) -> None:
    """Write *conf* to a temp file, create squid's swap dirs, then run squid.

    squid runs in the foreground (-N), so this returns only when squid
    exits. Raises Error if either invocation fails or *squid_bin* cannot be
    executed.
    """
    # TKLBAM currently uses a custom (legacy) squid build
    #
    # if/when it is replaced with a default squid install, '-D' will need to be
    # dropped as an option. As of Trixie (squid v6.13) '-D' is ignored and is
    # slated for removal in future versions. Disabling DNS needs to be done via
    # ACLs in config file - see: https://wiki.squid-cache.org/SquidFaq/SquidAcl
    with NamedTemporaryFile(
        "w", prefix="tklbam-squid-conf-", delete_on_close=False
    ) as tmp_fob:
        # temp files are created owner-only; make the config world-readable,
        # presumably so squid can still read it after dropping privileges
        os.chmod(tmp_fob.name, 0o644)
        tmp_fob.write(conf)
        # close (but keep) the file so its contents are flushed before squid
        # reads it; it is deleted when the with-block exits
        tmp_fob.close()
        try:
            # args:
            # -f <conf>     - use <conf> instead of default config file
            # -z            - create swap directories
            subprocess.run([squid_bin, "-f", tmp_fob.name, "-z"], check=True)
            # -N            - no daemon mode
            # -D            - disable initial DNS tests
            # -d <level>    - write debugging to stderr also
            subprocess.run(
                [squid_bin, "-f", tmp_fob.name, "-N", "-D", "-d", "1"],
                check=True,
            )
        except subprocess.CalledProcessError as e:
            raise Error(e) from e
        except OSError as e:
            # e.g. TKLBAM_SQUID_BIN points at a missing/non-executable file
            raise Error(f"cannot execute '{squid_bin}': {e}") from e


def main() -> None:
    """Parse arguments, prepare the cache dir, render squid.conf, run squid.

    Usage: <http_port> <cache_size>. The squid user, cache dir and squid
    binary default to the module constants and can be overridden by the
    TKLBAM_SQUID_USER, TKLBAM_SQUID_CACHE_DIR and TKLBAM_SQUID_BIN
    environment variables.
    """
    args = sys.argv[1:]
    # only inspect real arguments, not the program name in sys.argv[0]
    if any(opt in args for opt in ("-h", "--help")):
        usage()
    elif len(args) != 2:
        usage("incorrect number of arguments")

    user = os.environ.get("TKLBAM_SQUID_USER", TKLBAM_SQUID_USER)
    if not user_exists(user):
        fatal(f"no such user '{user}'")

    cache_dir = os.environ.get(
        "TKLBAM_SQUID_CACHE_DIR", TKLBAM_SQUID_CACHE_DIR
    )
    _prepare_cache_dir(cache_dir, user)

    http_port, size_arg = args
    try:
        cache_size = parse_cache_size(size_arg, cache_dir)
    except ValueError:
        usage(f"bad cache_size '{size_arg}'")
    except Error as e:
        # illegal format / out-of-range percent: report instead of a traceback
        usage(str(e))

    conf = _render_conf(http_port, cache_dir, cache_size)
    squid_bin = os.environ.get("TKLBAM_SQUID_BIN", TKLBAM_SQUID_BIN)
    _run_squid(squid_bin, conf)


if __name__ == "__main__":
    main()
