#!/usr/bin/python
#
# fishpolld - server to run scripts when triggered
#
# Copyright (C) 2008  Owen Taylor
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see
# http://www.gnu.org/licenses/.

from ConfigParser import RawConfigParser
from optparse import OptionParser
import os
import random
import re
import select
import signal
import socket
import subprocess
from syslog import *
import sys
from StringIO import StringIO
import time
import traceback

# Default TCP port, used when --listen does not include ":PORT"
FISHPOLL_PORT = 27527
# Default --listen host; '*' means all interfaces (it is bound as '')
FISHPOLL_LISTEN = '*'
# Default --pid-file location
FISHPOLL_PID_FILE = "/var/run/fishpolld.pid"
# Default --handler-dir: one executable per topic, plus optional <topic>.conf
FISHPOLL_HANDLER_DIR = "/etc/fishpoll.d"

# Used to know whether we should remove the pid file on exit.  Set to False
# by the parents that exit during daemonize() and by forked handler children.
main_process = True

# We use this instead of syslog when --debug is passed
def debug_log(level, message):
    """Print a syslog-style message to stdout as "<LEVEL>: <message>".

    level must be one of LOG_ERR, LOG_WARNING, LOG_INFO or LOG_DEBUG;
    any other value raises RuntimeError.
    """
    if level == LOG_ERR:
        level_str = "ERROR:"
    elif level == LOG_WARNING:
        level_str = "WARNING:"
    elif level == LOG_INFO:
        level_str = "INFO:"
    elif level == LOG_DEBUG:
        level_str = "DEBUG:"
    else:
        # "RuntimeException" is not a Python builtin; raising it gave a NameError
        raise RuntimeError("Bad level")

    # Single-argument print: same output as "print level_str, message" on
    # Python 2, and also valid on Python 3.
    print("%s %s" % (level_str, message))

# Read the first line out of a socket.  Anything received after the newline
# in the same recv() is returned too; callers trim it.
def read_first_line(s, max_bytes, timeout):
    """Accumulate data from socket s until a newline, EOF, or a limit.

    Stops after the recv() whose data contains "\\n", or when the peer
    closes the connection, and returns everything read so far.

    Raises RuntimeError("Timed out") once timeout seconds have passed, and
    RuntimeError("Too much data") once max_bytes have been read without
    the loop stopping.
    """
    deadline = time.time() + timeout
    received = StringIO()
    total = 0
    while True:
        remaining = deadline - time.time()
        if remaining <= 0:
            raise RuntimeError("Timed out")
        if total >= max_bytes:
            raise RuntimeError("Too much data")

        ready = select.select([s.fileno()], [], [], remaining)[0]
        if not ready:
            # select() gave up; go round again so the deadline is re-checked
            continue

        chunk = s.recv(max_bytes - total)
        if chunk == "":
            break  # peer closed the connection
        total += len(chunk)
        received.write(chunk)
        if "\n" in chunk:
            break

    return received.getvalue()

# Standard double-fork daemonization
def daemonize():
    """Detach from the controlling terminal with the classic double fork.

    Both intermediate parents set main_process = False (so they will not
    remove the pid file) and leave via sys.exit(0).  Only the final
    grandchild returns.  It runs in a new session with stdin read from
    /dev/null and stdout/stderr written to /dev/null.
    """
    global main_process

    pid = os.fork()
    if pid > 0:
        main_process = False
        sys.exit(0)

    os.setsid()

    def redirect(fd, targets):
        # dup2() closes each target itself.  Closing the target explicitly
        # first would break when a standard descriptor was already closed:
        # os.open() then returns that very number, so os.close(target) would
        # close fd and the following dup2() would fail with EBADF.
        for target in targets:
            os.dup2(fd, target)
        if fd not in targets:
            os.close(fd)

    redirect(os.open("/dev/null", os.O_RDONLY), (0,))
    redirect(os.open("/dev/null", os.O_WRONLY), (1, 2))

    pid = os.fork()
    if pid > 0:
        main_process = False
        sys.exit(0)

def write_pid_file():
    if options.pid_file:
        try:
            pid_file = open(options.pid_file, "w")
            pid_file.write("%d\n" % os.getpid())
            pid_file.close()
        except IOError, e:
            log(LOG_WARNING, "Cannot write pid to '%s': %s" % (options.pid_file, e.args[1]))

class Topic:
    def __init__(self, server, name, handler_path):
        self.server = server
        self.name = name
        self.handler_path = handler_path
        self.filename = False
        self.pid = None
        self.stdout_fileno = None
        self.stdout_buffer = None
        self.stderr_fileno = None
        self.stderr_buffer = None
        self.pending_subjects = set()
        self.check_on_start = True
        self.check_interval = -1

        conf_path = handler_path + ".conf"
        if os.path.exists(conf_path):
            log(LOG_DEBUG, "Reading configuration from %s" % (conf_path))
            parser = RawConfigParser()
            parser.read(conf_path)
            if not parser.has_section("fishpoll"):
                log(LOG_WARNING, "%s is missing a [fishpoll] section" % (conf_path))
            if parser.has_option("fishpoll", "on_start"):
                try:
                    self.check_on_start = parser.getboolean("fishpoll", "on_start")
                except ValueError, e:
                    log(LOG_WARNING, "on_start: %s" % (e,))
            if parser.has_option("fishpoll", "interval"):
                try:
                    self.check_interval = parser.getfloat("fishpoll", "interval")
                except ValueError, e:
                    log(LOG_WARNING, "interval: %s" % (e,))

            log(LOG_DEBUG, "on_start = %s" % self.check_on_start)
            log(LOG_DEBUG, "interval = %s" % self.check_interval)

        if self.check_interval >= 0:
            self.last_interval = time.time() - self.check_interval * random.random()

    def run_command(self):
        global main_process

        log(LOG_INFO, "Running handler for %s subjects=%s" % (self.name, ",".join(self.pending_subjects)))

        out_pipe = os.pipe()
        err_pipe = os.pipe()

        pid = os.fork()
        if pid > 0:
            self.pending_subjects = set()

            self.stdout_fileno = out_pipe[0]
            os.close(out_pipe[1])
            self.stdout_buffer = StringIO()

            self.stderr_fileno = err_pipe[0]
            os.close(err_pipe[1])
            self.stderr_buffer = StringIO()

            self.pid = pid
        else:
            main_process = False

            self.server.sock.close()

            devnull = os.open("/dev/null", os.O_RDONLY)
            os.close(0)
            os.dup2(devnull, 0)
            os.close(devnull)

            os.close(out_pipe[0])
            os.close(1)
            os.dup2(out_pipe[1], 1)
            os.close(out_pipe[1])

            os.close(err_pipe[0])
            os.close(2)
            os.dup2(err_pipe[1], 2)
            os.close(err_pipe[1])

            pid = os.fork()
            if pid > 0:
                _, status = os.waitpid(pid, 0)
                if status == 0:
                    sys.exit(0)
                else:
                    sys.exit(1)
            else:
                args = [self.name]
                args.extend(sorted(self.pending_subjects))

                env = {}
                env['PATH'] = "/usr/local/bin:/bin:/usr/bin"
                env['TERM'] = "dumb"
                os.execve(self.handler_path, args, env)

    def add_subjects(self, subjects):
        self.pending_subjects.update(subjects)

    def _check_complete(self):
        if self.stderr_fileno != None or self.stdout_fileno != None:
            return

        _, status = os.waitpid(self.pid, 0)
        self.pid = None

        if status == 0:
            log(LOG_INFO, "Fishpoll of %s succeeded" % self.name)
        else:
            stderr = self.stderr_buffer.getvalue()
            if stderr:
                for line in stderr.split("\n"):
                    if line != "":
                        log(LOG_ERR, line)
            log(LOG_ERR, "Fishpoll of %s failed" % self.name)

        if len(self.pending_subjects) > 0:
            self.run_command()

    def read_stdout(self):
        bytes = os.read(self.stdout_fileno, 16384)
        if bytes == "":
            os.close(self.stdout_fileno)
            self.stdout_fileno = None
        else:
            self.stdout_buffer.write(bytes)

        self._check_complete()

    def read_stderr(self):
        bytes = os.read(self.stderr_fileno, 16384)
        if bytes == "":
            os.close(self.stderr_fileno)
            self.stderr_fileno = None
        else:
            self.stderr_buffer.write(bytes)

        self._check_complete()

class Server:
    """The listening socket plus the topics found in the handler directory."""

    def __init__(self):
        """Bind and listen on options.listen, then load the topics.

        options.listen has the form "HOST[:PORT]", where "*" means all
        addresses.  The process exits with status 1 if the bind fails.
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # This allows reuse again immediately after we exit, even if
        # there is still a connection in the TIME_WAIT state
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        m = re.match(r"^(.*?)(?::(\d+))?$", options.listen.strip())
        host = m.group(1)
        if m.group(2) is not None:
            port = int(m.group(2))
        else:
            port = FISHPOLL_PORT

        log(LOG_INFO, "Listening on %s:%s" % (host, port))

        if host == '*':
            host = ''

        try:
            self.sock.bind((host, port))
        except socket.error, e:
            print >>sys.stderr, "Can't bind to socket: %s" % e.args[1]
            sys.exit(1)
        self.sock.listen(3) # backlog of 3

        self._read_topics()

    def _read_topics(self):
        """Build self.topics (name -> Topic) from the executables in options.handler_dir.

        Backup files (~), .conf files, dotfiles, names that are not valid
        topic names, non-regular files and non-executables are skipped.
        """
        self.topics = {}

        try:
            names = os.listdir(options.handler_dir)
        except OSError, e:
            log(LOG_WARNING, "Cannot read handlers from '%s': %s" % (options.handler_dir, e.args[1]))
            return

        # Reuse the listing above instead of calling listdir() a second,
        # unprotected time.
        for name in names:
            handler_path = os.path.join(options.handler_dir, name)
            if re.match(r".*~$", name):
                continue
            if re.match(r".*\.conf$", name):
                continue
            if re.match(r"^\.", name):
                continue
            if not re.match(r"^[A-Za-z0-9_]+$", name):
                log(LOG_WARNING, "Ignoring '%s', not a valid topic name" % handler_path)
                continue
            if not os.path.isfile(handler_path):
                log(LOG_WARNING, "Ignoring '%s', not a regular file" % handler_path)
                continue
            if not os.access(handler_path, os.X_OK):
                log(LOG_WARNING, "Ignoring '%s', not executable" % handler_path)
                continue

            self.topics[name] = Topic(self, name, handler_path)
            log(LOG_DEBUG, "Adding handler %s for topic %s" % (handler_path, name))

    def read_command(self):
        """Accept one connection and handle the single-line command sent on it.

        The command is "TOPIC [SUBJECT...]"; with no subjects, DEFAULT is used.
        The reply is "OK\\n" or "ERROR: <reason>\\n".  Network errors on the
        connection are logged and never propagate, so one misbehaving client
        cannot stop the server.  SIGTERM and Ctrl-C are still passed through.
        """
        topic = None

        try:
            conn, addr = self.sock.accept()
        except socket.error, e:
            log(LOG_WARNING, "Error accepting connection: %s" % (e,))
            return

        try:
            try:
                res = read_first_line(conn, max_bytes=16384, timeout=15)
                newline = res.find("\n")
                if newline >= 0:
                    res = res[0:newline]

                res = res.strip()
                if not re.match(r"^[A-Za-z0-9_]+(\s+[A-Za-z0-9_.]+)*$", res):
                    raise RuntimeError("Bad command")

                parts = re.split(r"\s+", res)
                topic_name = parts[0]
                subjects = parts[1:]
                if subjects == []:
                    subjects = ['DEFAULT']

                if not topic_name in self.topics:
                    raise RuntimeError("Unknown topic '%s'" % topic_name)

                topic = self.topics[topic_name]
                topic.add_subjects(subjects)

                log(LOG_INFO, "Received %s subjects=%s (%s)" % (topic_name, ",".join(subjects), addr[0]))
                reply = "OK\n"
            except (KeyboardInterrupt, TerminatedError):
                # Shutdown requests must reach main(), not be reported to the client
                raise
            except Exception, e:
                log(LOG_ERR, "Error (%s): %s" % (addr[0], e))
                reply = "ERROR: " + str(e) + "\n"

            try:
                conn.sendall(reply)
            except socket.error, e:
                # Client went away (EPIPE / ECONNRESET); nothing more to do
                log(LOG_WARNING, "Error sending reply to %s: %s" % (addr[0], e))
        finally:
            conn.close()

        if topic and not topic.pid and len(topic.pending_subjects) > 0:
            topic.run_command()

    def run(self):
        """Run the server loop forever.

        First starts every topic with check_on_start.  Then it multiplexes
        the listening socket, the handlers' output pipes and the per-topic
        interval timers with select().
        """
        for topic in self.topics.values():
            if topic.check_on_start:
                topic.add_subjects(['DEFAULT'])
                topic.run_command()

        now = time.time()
        while True:
            timeout = None

            read_list = [self.sock.fileno()]
            for topic in self.topics.values():
                if topic.stdout_fileno is not None:
                    read_list.append(topic.stdout_fileno)
                if topic.stderr_fileno is not None:
                    read_list.append(topic.stderr_fileno)

                if topic.check_interval > 0:
                    topic_timeout = max(0, topic.last_interval + topic.check_interval - now)

                    if timeout is None:
                        timeout = topic_timeout
                    else:
                        timeout = min(timeout, topic_timeout)

            read_ready = select.select(read_list, [], [], timeout)[0]
            now = time.time()

            if self.sock.fileno() in read_ready:
                self.read_command()

            for topic in self.topics.values():
                # Decide readability before touching either pipe.  Reading EOF
                # on one pipe can finish the run and start a new one, and the
                # new pipes may reuse fd numbers that are in read_ready.  A
                # blocking read on such a fresh pipe would stall the server.
                stdout_ready = topic.stdout_fileno is not None and topic.stdout_fileno in read_ready
                stderr_ready = topic.stderr_fileno is not None and topic.stderr_fileno in read_ready
                if stdout_ready:
                    topic.read_stdout()
                if stderr_ready:
                    topic.read_stderr()
                if topic.check_interval > 0:
                    remaining = max(0, topic.last_interval + topic.check_interval - now)
                    if remaining <= 0:
                        topic.last_interval = now
                        topic.add_subjects(['DEFAULT'])
                        if not topic.pid:
                            topic.run_command()

# SIGTERM is turned into an exception so main() can log the shutdown and
# its cleanup code (pid-file removal) gets a chance to run.
class TerminatedError(Exception):
    """Raised by the SIGTERM handler to unwind out of the server loop."""


def handle_sigterm(signal, stack):
    """SIGTERM handler: raise TerminatedError in whatever code is running."""
    raise TerminatedError

def main():
    """Parse options, set up logging and the server, daemonize, and serve.

    SIGTERM, Ctrl-C and unexpected exceptions are logged before exiting.
    On the way out the pid file is removed, but only by the process for
    which main_process is still True.
    """
    global log
    global options

    signal.signal(signal.SIGTERM, handle_sigterm)

    parser = OptionParser()
    parser.add_option("-d", "--debug", action='store_true',
                      help="do not daemonize and log to stdout")
    parser.add_option("", "--pid-file",
                      help="location to write PID of daemon")
    parser.add_option("", "--handler-dir",
                      help="directory to look for handlers in")
    parser.add_option("", "--listen", metavar="HOST[:PORT]",
                      help="address to listen on")

    parser.set_defaults(pid_file=FISHPOLL_PID_FILE,
                        handler_dir=FISHPOLL_HANDLER_DIR,
                        listen=FISHPOLL_LISTEN)

    options, args = parser.parse_args()
    if len(args) > 0:
        parser.print_usage()
        sys.exit(1)

    if options.debug:
        log = debug_log
    else:
        openlog('fishpolld', 0, LOG_DAEMON)
        log = syslog

    # Use a standard environment.
    # NOTE(review): with umask 0, files created with default permissions (by
    # the daemon or by handlers, which inherit it) are world-writable;
    # consider 022.
    os.chdir("/")
    os.umask(0)

    server = Server()

    try:
        try:
            if not options.debug:
                daemonize()
            write_pid_file()

            log(LOG_INFO, "Starting")
            server.run()
        except SystemExit:
            # Python-2.4 - SystemExit inherits from Exception
            # Pass through so we get the right exit code
            raise
        except KeyboardInterrupt:
            log(LOG_INFO, "Exiting on keyboard interrupt")
        except TerminatedError:
            log(LOG_INFO, "Exiting on SIGTERM")
        except Exception:
            log(LOG_ERR, "Exiting on unexpected exception")
            for line in traceback.format_exc().strip().split("\n"):
                log(LOG_ERR, line)
    finally:
        # Processes that exited during daemonize() and forked handler
        # children have main_process == False and must leave the file alone.
        if main_process and options.pid_file:
            try:
                os.remove(options.pid_file)
            except OSError:
                pass  # best effort: it may never have been written

# Start the daemon only when run as a script, not when imported
if __name__ == '__main__':
    main()
