#! /usr/bin/env python
# vim:set sts=4 ts=8 sw=4:

"""Program to generate Xrl Target related files"""

from optparse import OptionParser
import os, sys

# This is a bit of a mess, as this code was split into separate files.
import Xif.util

from Xif.util import \
     joining_csv, csv, cpp_name, cpp_classname, xorp_indent_string, \
     xorp_indent, method_name_of_xrl, service_name_of_xrl, \
     xrl_method_name

from Xif.xiftypes import \
     XrlArg, XrlMethod, XrlInterface, XrlTarget

from Xif.parse import \
     XifParser

from Xif.thrifttypes import \
    wire_type, send_arg, recv_arg

# -----------------------------------------------------------------------------
# Target file output related
# -----------------------------------------------------------------------------

def target_declare_service_tables(cls, interfaces):
    """Return the C++ declarations of the Thrift dispatch tables.

    The returned text declares, in order: the `method_entry` struct (whose
    member-function pointer is scoped to class `cls`), one static
    `<interface>_methods[]` array per value in `interfaces`, the
    `service_entry` struct, and the static `services[]` array.

    cls        -- name of the generated C++ class.
    interfaces -- dictionary whose values provide a name() method.
    """
    pieces = [
        "\n"
        "    // Thrift method table for each service in this XRL target\n"
        "    struct method_entry {\n"
        "        const char *name;\n"
        "        const XrlCmdError (%s::*method)(const int32_t);\n"
        "    };\n"
        "\n" % cls
    ]

    for iface in interfaces.values():
        pieces.append("    static const struct method_entry %s_methods[];\n"
                      % iface.name())

    # The tab in front of the `methods` member is part of the emitted text.
    pieces.append(
        "\n"
        "    // Thrift service table for this XRL target\n"
        "    struct service_entry {\n"
        "        const char *name;\n"
        "\tconst struct method_entry * const methods;\n"
        "    };\n"
        "\n"
        "    static const struct service_entry services[];\n"
    )
    pieces.append("\n")
    return "".join(pieces)

def target_define_method_table(cls, interface):
    """Return the C++ definition of one interface's method table.

    Each method of `interface` contributes a row pairing its name with a
    pointer to `cls::handle_<interface>_<method>`; the table is closed by a
    `{ 0, 0 }` sentinel row.

    cls       -- name of the generated C++ class.
    interface -- object providing name() and methods().
    """
    iface_name = interface.name()
    header = "const struct %s::method_entry\n" % cls
    header += "%s::%s_methods[] = {\n" % (cls, iface_name)

    rows = []
    for method in interface.methods():
        rows.append("    { \"%s\",\n" % method.name())
        rows.append("      &%s::handle_%s_%s },\n"
                    % (cls, iface_name, cpp_name(method.name())))

    footer = "    { 0, 0 }\n" + "};\n\n"
    return header + "".join(rows) + footer

# thrift tables are used by per-endpoint dispatch, in place
# of the centralized XrlDispatcher (service dispatch is now local to
# the endpoint and doesn't cross endpoint boundaries).

def target_define_service_tables(cls, interfaces):
    """Return the C++ definitions of every method table and the service table.

    Emits one method table per value in `interfaces` (through
    target_define_method_table), followed by `cls::services[]`, which maps
    each interface name to its method table and ends with a `{ 0, 0 }`
    sentinel row.

    cls        -- name of the generated C++ class.
    interfaces -- dictionary whose values provide name() and methods().
    """
    method_tables = "".join(
        target_define_method_table(cls, iface)
        for iface in interfaces.values())

    # One service-table row per interface.
    service_rows = "".join(
        "    { \"%s\",\n" % iface.name()
        + "      %s::%s_methods },\n" % (cls, iface.name())
        for iface in interfaces.values())

    chunks = [
        "// Thrift method tables\n",
        method_tables,
        "\n",
        "// Thrift service tables\n",
        "const struct %s::service_entry\n" % cls,
        "%s::services[] = {\n" % cls,
        service_rows,
        "    { 0, 0 }\n",
        "};\n\n",
    ]
    return "".join(chunks)

def target_declare_virtual_fns(tgt, interfaces):
    """Return pure-virtual C++ declarations for every method of every interface.

    The C++ function name is derived from the fully-qualified XRL-style
    method name (interface name, interface version, method name) so that
    legacy code implementing these RPC stubs does not need rewriting.
    Input arguments are declared as const references, output arguments as
    plain references.

    tgt        -- not used by this function.
    interfaces -- dictionary whose values provide name(), version() and
                  methods().
    """
    declarations = []
    for iface in interfaces.values():
        for method in iface.methods():
            fq_name = xrl_method_name(iface.name(), iface.version(),
                                      method.name())
            params = []

            inputs = method.args()
            if len(inputs):
                params.append("\n%s// Input values" % xorp_indent(2))
            params.extend(
                "\n%sconst %s&\t%s"
                % (xorp_indent(2), arg.cpp_type(), cpp_name(arg.name()))
                for arg in inputs)

            outputs = method.rargs()
            if len(outputs):
                params.append("\n%s// Output values" % xorp_indent(2))
            params.extend(
                "\n%s%s&\t%s"
                % (xorp_indent(2), arg.cpp_type(), cpp_name(arg.name()))
                for arg in outputs)

            declarations.append("    virtual XrlCmdError %s(" % cpp_name(fq_name)
                                + csv(params)
                                + ") = 0;\n\n")
    return "".join(declarations)

def target_declare_services(interfaces):
    """Return C++ declarations of the per-method Thrift handler functions.

    One `handle_<interface>_<method>` declaration is produced for each
    method of each interface; every handler takes the single parameter
    `const int32_t rseqid`.

    interfaces -- dictionary whose values provide name() and methods().
    """
    declarations = []
    for iface in interfaces.values():
        for method in iface.methods():
            # Each parameter goes on its own line, indented by one tab.
            params = ["\n\t" + param for param in ["const int32_t\trseqid"]]
            declarations.append(
                "    const XrlCmdError handle_%s_%s("
                % (iface.name(), cpp_name(method.name()))
                + csv(params)
                + ");\n\n")
    return "".join(declarations)

def target_declare_service_hooks():
    """Return the C++ declarations of add_services() and remove_services()."""
    hook_names = ("add_services", "remove_services")
    return "".join("    void %s();\n" % hook for hook in hook_names)

def target_define_service_hooks(cls, service_name):
    """Return the C++ definitions of cls::add_services() and remove_services().

    The generated functions walk the static `services[]` table and each
    service's method table, registering (or unregistering) every entry with
    the ServiceTable held in `_st`.

    Both tables are terminated by a `{ 0, 0 }` sentinel row, so the generated
    loops stop when an entry's `name` is null.  Comparing the iterator
    pointer itself against 0, as was done previously, never terminates
    and walks off the end of the array.

    cls          -- name of the generated C++ class.
    service_name -- not used by this function.

    NOTE(review): the generated code still only carries "raise an exception"
    placeholders; in particular a null `mt` is not guarded before use.
    """
    lines = [
        "void",
        "%s::add_services()" % cls,
        "{",
        "    for (const service_entry * sp = &services[0]; sp->name != 0; sp++) {",
        "        MethodTable* mt = _st->add_service(sp->name);",
        "        if (0 == mt) {",
        "            // raise an exception",
        "        }",
        "        for (const method_entry * mp = sp->methods; mp->name != 0; mp++) {",
        "            if (! mt->add_method(mp->name,",
        "                                 callback(this, mp->method))) {",
        "                // raise an exception",
        "            }",
        "        }",
        "        //_st->finalize();",
        "    }",
        "}",
        "",
        "void",
        "%s::remove_services()" % cls,
        "{",
        "    for (const service_entry * sp = &services[0]; sp->name != 0; sp++) {",
        "        MethodTable* mt = _st->method_table(sp->name);",
        "        for (const method_entry * mp = sp->methods; mp->name != 0; mp++) {",
        "            if (! mt->remove_method(mp->name)) {",
        "                // raise an exception",
        "            }",
        "        }",
        "        if (! _st->remove_service(sp->name)) {",
        "            // raise an exception",
        "        }",
        "    }",
        "}",
    ]
    return "\n".join(lines) + "\n"

#
# Generate Thrift input code for XRL target method.
# Assumes message header for a T_CALL was read on inp already.
#
# TODO: Check for duplicate field IDs.
#
def service_method_input(cls, service_name, method):
    """Return C++ code that reads a method's input struct from `inp`.

    For a method without input arguments the generated code just skips the
    incoming struct.  Otherwise it declares one local per argument, then
    loops over the incoming fields: field id N (1-based, in declaration
    order) is read with the lines produced by recv_arg() when its wire type
    matches wire_type(); any other field is skipped.  Unless
    XIF_DISABLE_TARGET_INPUT_CHECKS is defined, the generated code records
    which arguments were seen in a bitset and returns
    XrlCmdError::BAD_ARGS() when not all of them were present.

    A field with an expected id but an unexpected wire type is skipped
    explicitly.  Previously such a field was left unread, so the following
    readFieldEnd()/readFieldBegin() calls would run against its payload.

    cls, service_name -- not used by this function.
    method            -- object providing args(); each argument provides
                         cpp_type() and name().
    """
    arg_checks_enabled = True
    in_args = method.args()
    out = []

    if len(in_args) == 0:
        out.append("    /* This method has no input arguments, skip input. */\n")
        out.append("    nin += inp->skip(T_STRUCT);\n")
    else:
        if arg_checks_enabled:
            out.append("#ifndef XIF_DISABLE_TARGET_INPUT_CHECKS\n")
            out.append("    bitset<%d> argf;\n" % len(in_args))
            out.append("#endif\n")
        out.append("\n")
        out.append("    /* Input value declarations */\n")
        for a in in_args:
            out.append("    %s %s;\n" % (a.cpp_type(), cpp_name(a.name())))
        out.append("\n")
        out.append("    /* Begin parsing method input struct */\n")
        out.append("    string fname;\n")
        out.append("    TType ftype;\n")
        out.append("    int16_t fid;\n")
        out.append("\n")
        out.append("    nin += inp->readStructBegin(fname);\n")
        out.append("    for (;;) {\n")
        out.append("        nin += inp->readFieldBegin(fname, ftype, fid);\n")
        out.append("        if (ftype == T_STOP)\n")
        out.append("            break;\n")
        out.append("        switch (fid) {\n")
        # Parse each field expected in this struct; field ids start at 1,
        # bitset positions at 0.
        for index, a in enumerate(in_args):
            out.append("        case %d:\n" % (index + 1))
            out.append("            if (ftype == %s) {\n" % wire_type(a))
            for l in recv_arg(a):
                out.append(xorp_indent(4) + l + "\n")
            if arg_checks_enabled:
                out.append("#ifndef XIF_DISABLE_TARGET_INPUT_CHECKS\n")
                out.append("                argf.set(%d);\n" % index)
                out.append("#endif\n")
            # Wrong wire type: consume the payload so the stream stays in
            # sync; the argument stays unset and is caught by the final check.
            out.append("            } else {\n")
            out.append("                nin += inp->skip(ftype);\n")
            out.append("            }\n")
            out.append("            break;\n")
        # Default is to skip unknown fields.
        out.append("        default:\n")
        out.append("            nin += inp->skip(ftype);\n")
        out.append("            break;\n")
        out.append("        }\n")
        out.append("        nin += inp->readFieldEnd();\n")
        out.append("    }\n")
        out.append("    /* End parsing method input struct */\n")
        out.append("\n")
        out.append("    nin += inp->readStructEnd();\n")

    # End of input argument parsing.
    out.append("    nin += inp->readMessageEnd();\n")
    out.append("    inp->getTransport()->readEnd();\n")
    out.append("\n")

    # Final argument count check.
    if len(in_args) > 0 and arg_checks_enabled:
        out.append("#ifndef XIF_DISABLE_TARGET_INPUT_CHECKS\n")
        out.append("    if (argf.count() != %d) {\n" % len(in_args))
        out.append("        return XrlCmdError::BAD_ARGS();\n")
        out.append("    }\n")
        out.append("#endif\n")
        out.append("\n")
    return "".join(out)

def service_method_dispatch(cls, interface, method):
    """Return C++ code that calls the user-implemented handler for `method`.

    The generated code declares one local per return value, calls the
    handler (named after the fully-qualified XRL-style method name) with the
    input locals followed by the output locals, and returns the resulting
    XrlCmdError early when it is not OKAY.

    Argument names are passed through cpp_name(), the same transformation
    used where the corresponding C++ locals are declared, so the call refers
    to the names that were actually declared.

    cls       -- not used by this function.
    interface -- object providing name() and version().
    method    -- object providing name(), args() and rargs().
    """
    fqm = xrl_method_name(interface.name(), interface.version(), method.name())
    tab = xorp_indent(1)

    s = tab + "/* Return value declarations */\n"
    for r in method.rargs():
        s += tab + "%s %s;\n" % (r.cpp_type(), cpp_name(r.name()))

    # Inputs first, then outputs, matching the declared handler signature.
    params = [cpp_name(a.name()) for a in method.args()]
    params += [cpp_name(r.name()) for r in method.rargs()]
    s += tab + "XrlCmdError e = %s(" % cpp_name(fqm)
    s += csv(params, ", ") + ");\n"
    s += "\n"

    s += tab + "if (e != XrlCmdError::OKAY()) {\n"
    # This line is not %-formatted, so the C format specifiers are written
    # as plain "%s" (a doubled "%%s" would be emitted literally).
    s += tab + "    //XLOG_WARNING(\"Handling method for %s failed: %s\",\n"
    s += tab + "    //             \"%s\", e.str().c_str());\n" % method.name()
    s += tab + "    return e;\n"
    s += tab + "}\n"
    s += "\n"
    return s

#
# Generate return value generator for Thrifted XIF method.
# These are always wrapped up as a struct, implicitly. We go one
# step further and we always wrap up the result in *another*
# struct if there are any return values.
#
def service_method_output(cls, service_name, method):
    """Return C++ code that writes the T_REPLY message for `method` to `outp`.

    The reply is an outer `<service>_<method>_result` struct.  When the
    method has return values, that struct carries a single field "success"
    (id 0, T_STRUCT) holding an inner `<method>_result` struct with one
    field per return value, numbered from 1 in declaration order.

    The "success" field is closed with writeFieldEnd() after its struct;
    previously the writeFieldBegin() call had no matching end call.

    cls, service_name -- not used by this function.
    method            -- object providing name() and rargs().

    NOTE(review): the transport is flushed before writeEnd() is called;
    confirm this order is what the transport in use expects.
    """
    tab = xorp_indent(1)
    mname = method_name_of_xrl(method.name())
    rname = mname + "_result"
    srname = service_name_of_xrl(method.name()) + "_" + rname

    s = tab + "/* Marshall return values */\n"
    s += tab + "nout += outp->writeMessageBegin(\"%s\", T_REPLY, rseqid);\n" % \
        (mname)
    s += tab + "nout += outp->writeStructBegin(\"%s\");\n" % srname
    if len(method.rargs()) > 0:
        s += tab + "nout += outp->writeFieldBegin(\"success\", T_STRUCT, 0);\n"
        # At this point, we start spoofing up a struct. This could in
        # fact be a single field, but we force XIF->Thrift translated IDL
        # to always wrap return values in a real struct.
        s += tab + "nout += outp->writeStructBegin(\"%s\");\n" % rname
        for index, r in enumerate(method.rargs()):
            lines = ["nout += outp->writeFieldBegin(\"%s\", %s, %d);" %
                     (r.name(), wire_type(r), index + 1)]
            lines += send_arg(r)
            lines.append("nout += outp->writeFieldEnd();")
            for l in lines:
                s += tab + l + "\n"
        s += tab + "nout += outp->writeFieldStop();\n"
        s += tab + "nout += outp->writeStructEnd();\n"
        # Close the "success" field opened above.
        s += tab + "nout += outp->writeFieldEnd();\n"
    s += tab + "nout += outp->writeFieldStop();\n"
    s += tab + "nout += outp->writeStructEnd();\n"
    s += tab + "nout += outp->writeMessageEnd();\n"
    s += "\n"
    s += tab + "outp->getTransport()->flush();\n"
    s += tab + "outp->getTransport()->writeEnd();\n"
    s += "\n"
    return s

#
# Generate a Thrift-based target method from XIF method definition.
#
def service_method(cls, interface, method):
    """Return the full C++ definition of cls::handle_<interface>_<method>.

    The body is assembled from a fixed prologue followed by the text from
    service_method_input(), service_method_dispatch() and
    service_method_output(), and ends by returning XrlCmdError::OKAY().

    NOTE(review): the prologue initialises `inp` and `outp` to 0 and nothing
    emitted here assigns them before the generated code uses them.

    cls       -- name of the generated C++ class.
    interface -- object providing name() (and whatever the helpers need).
    method    -- object providing name() (and whatever the helpers need).
    """
    # Each parameter goes on its own line, indented by one tab.
    params = ["\n\t" + param for param in ["const int32_t\trseqid"]]

    signature = "const XrlCmdError\n%s::handle_%s_%s(" % \
        (cls, interface.name(), cpp_name(method.name()))

    prologue = (
        "\n)\n"
        "{\n"
        "    using namespace apache::thrift::protocol;\n"
        "\n"
        "    TProtocol *inp = 0;\n"
        "    TProtocol *outp = 0;\n"
        "    uint32_t nin = 0;\n"
        "    uint32_t nout = 0;\n"
        "\n"
    )

    sections = [
        signature,
        csv(params),
        prologue,
        service_method_input(cls, interface.name(), method),
        service_method_dispatch(cls, interface, method),
        service_method_output(cls, interface.name(), method),
        "    return XrlCmdError::OKAY();\n",
        "}\n\n",
    ]
    return "".join(sections)

def target_define_service_methods(cls, interfaces):
    """Return the concatenated C++ handler definitions for all methods.

    cls        -- name of the generated C++ class.
    interfaces -- dictionary whose values provide methods().
    """
    return "".join(
        service_method(cls, iface, method)
        for iface in interfaces.values()
        for method in iface.methods())

def protect(file):
    """Return the include-guard macro name for the header path `file`.

    The directory component (everything up to the last "/") is dropped, the
    rest is upper-cased and every "." is replaced with "_".
    """
    # Remove directory component.
    leaf = file.rsplit("/", 1)[-1]
    macro_stem = leaf.upper().replace(".", "_")
    return "__XRL_TARGETS_" + macro_stem + "__"

def prepare_target_hh(modulename, hh_file):
    """Return the opening text of the generated header file.

    Consists of the preamble from Xif.util.standard_preamble(), the opening
    of the include guard derived from `hh_file`, the XORP_LIBRARY_NAME
    definition set to `modulename`, and two forward declarations.
    """
    preamble = Xif.util.standard_preamble(1, hh_file)
    guard = protect(hh_file)
    opening = (
        "\n"
        "#ifndef %(guard)s\n"
        "#define %(guard)s\n"
        "\n"
        "#undef XORP_LIBRARY_NAME\n"
        "#define XORP_LIBRARY_NAME \"%(module)s\"\n"
        "\n"
        "//#include \"libxorp/xlog.h\"\n"
        "\n"
        "class XrlRouter;\n"
        "class ServiceTable;\n"
        "\n"
    ) % {"guard": guard, "module": modulename}
    return preamble + opening

def output_target_hh(cls, tgt, interfaces):
    """Return the C++ class declaration for one XRL target.

    The class holds a ServiceTable pointer and a name, reports
    "<target name>/<target version>" from version(), and declares the
    pure-virtual method stubs, the Thrift handler functions, the service
    hooks and the dispatch tables.

    cls        -- name of the generated C++ class.
    tgt        -- object providing name() and version().
    interfaces -- dictionary of the interfaces to declare.
    """
    class_head = """
class %(cls)s {
private:
    ServiceTable* _st;
    string _name;

public:
    %(cls)s(ServiceTable* st = 0);

    virtual ~%(cls)s();

    const string& name() const { return _name; }
    const char* version() const { return "%(name)s/%(version)s"; }

protected:
""" % {"cls": cls, "name": tgt.name(), "version": tgt.version()}

    sections = [
        class_head,
        target_declare_virtual_fns(tgt, interfaces),
        "private:\n",
        target_declare_services(interfaces),
        target_declare_service_hooks(),
        target_declare_service_tables(cls, interfaces),
        "};\n",
    ]
    return "".join(sections)

def finish_target_hh(hh_file):
    """Return the closing `#endif` line of the header's include guard."""
    guard = protect(hh_file)
    return "\n#endif // " + guard + "\n"

def prepare_target_cc(target_hh, target_cc):
    """Return the opening text of the generated .cc file.

    Consists of the preamble from Xif.util.standard_preamble(), a fixed list
    of #include directives, and finally an #include of the generated header,
    referenced by its file name without any directory component.
    """
    preamble = Xif.util.standard_preamble(0, target_cc)

    # Fixed include section; empty entries are blank lines in the output.
    include_lines = [
        "",
        "",
        "//#include \"libxorp/xlog.h\"",
        "#include \"libxorp/callback.hh\"",
        "",
        "#include \"libxipc/xrl_error.hh\"",
        "#include \"libxipc/xrl_atom.hh\"",
        "#include \"libxipc/xrl_atom_list.hh\"",
        "#include \"libxipc/xrl_router.hh\"",
        "#include \"libxipc/xrl_sender.hh\"    // XXX needed?",
        "",
        "#include \"libxipc/service_table.hh\"",
        "",
        "#include <boost/scoped_array.hpp>",
        "#include <boost/shared_ptr.hpp>",
        "#include <boost/static_assert.hpp>",
        "",
        "#include <Thrift.h>",
        "#include <transport/TTransport.h>",
        "#include <protocol/TProtocol.h>",
        "#include <protocol/TBinaryProtocol.h>",
        "",
        "#include \"libxipc/xif_thrift.hh\"    // XXX for xif_read_*()",
        "",
        "#include <bitset>\t\t// for argument counts",
        "",
    ]
    includes = "".join(line + "\n" for line in include_lines)

    # Strip the directory component from the header path.
    header_leaf = target_hh[target_hh.rfind("/") + 1:]
    own_header = "\n#include \"%s\"\n\n" % header_leaf
    return preamble + includes + own_header

def output_target_cc(cls, tgt, interfaces):
    """Return the C++ definitions for one XRL target class.

    Emits the dispatch tables, the constructor and destructor (which call
    add_services() / remove_services() when a ServiceTable was supplied),
    the Thrift handler methods and the service hooks.

    cls        -- name of the generated C++ class.
    tgt        -- object providing name(); also handed on to
                  target_define_service_hooks().
    interfaces -- dictionary of the interfaces to define.
    """
    tables = target_define_service_tables(cls, interfaces)

    # The tabs before add_services()/remove_services() are part of the output.
    ctor_dtor = (
        "\n"
        "// Begin class definitions\n"
        "\n"
        "%(cls)s::%(cls)s(ServiceTable* st)\n"
        "    : _st(st), _name(\"%(name)s\")\n"
        "{\n"
        "    if (_st)\n"
        "\tadd_services();\n"
        "}\n"
        "\n"
        "%(cls)s::~%(cls)s()\n"
        "{\n"
        "    if (_st)\n"
        "\tremove_services();\n"
        "}\n"
        "\n"
    ) % {"cls": cls, "name": tgt.name()}

    methods = target_define_service_methods(cls, interfaces)
    hooks = target_define_service_hooks(cls, tgt)
    return tables + ctor_dtor + methods + hooks

def main():
    """Command-line entry point: generate `<basename>_base.hh` and `.cc`.

    Takes exactly one positional argument, the input file, which is run
    through `cpp -C` (with any -I directories) and handed to XifParser.
    All parsed targets must come from a single source file whose name ends
    in ".tgt"; the output file names are derived from that file's basename
    and placed in --output-dir when given.

    Exits with status 1 after printing a message when no targets are found,
    when targets come from more than one source file, or when the source
    file does not end in ".tgt".  A wrong number of arguments is reported
    through the option parser's error().
    """
    usage = "usage: %prog [options] arg"
    parser = OptionParser(usage)
    parser.add_option("-o", "--output-dir",
                      action="store",
                      type="string",
                      dest="output_dir",
                      metavar="DIR")
    parser.add_option("-I",
                      action="append",
                      type="string",
                      dest="includes",
                      metavar="DIR")
    (options, args) = parser.parse_args()

    if len(args) != 1:
        # parser.error() prints the usage text plus this message and exits.
        parser.error("incorrect number of arguments")

    # Command line arguments passed on to cpp.
    # NOTE(review): the include directories and the input path are placed
    # into a shell command line unquoted, so paths containing spaces or
    # shell metacharacters will not work as intended.
    cpp_command = ["cpp", "-C"]
    cpp_command.extend("-I%s" % directory
                       for directory in options.includes or [])
    cpp_command.append(args[0])

    cpp_pipe = os.popen(" ".join(cpp_command), 'r')

    xp = XifParser(cpp_pipe)

    tgts = xp.targets()
    if len(tgts) == 0:
        print("No targets found in input files.")
        sys.exit(1)

    sourcefile = tgts[0].sourcefile()
    for tgt in tgts:
        if tgt.sourcefile() != sourcefile:
            print("Multiple .tgt files presented, expected just one.")
            sys.exit(1)

    # basename transformation - this is a lame test
    if sourcefile[-4:] != ".tgt":
        print("Source file does not end in .tgt suffix - basename transform failure.")
        sys.exit(1)

    basename = sourcefile[:-4]
    basename = basename[basename.rfind("/") + 1:]

    modulename = "Xrl%sTarget" % cpp_classname(basename)
    hh_file = "%s_base.hh" % basename
    cc_file = "%s_base.cc" % basename

    if options.output_dir:
        hh_file = os.path.join(options.output_dir, hh_file)
        cc_file = os.path.join(options.output_dir, cc_file)

    hh_txt = prepare_target_hh(modulename, hh_file)
    cc_txt = prepare_target_cc(hh_file, cc_file)

    for tgt in tgts:
        # Because interfaces are loaded and parsed separately from targets,
        # XrlTarget.interfaces is a list of tuples ("interface_name",
        # "interface_version").
        # Produce a dictionary of interfaces actually referenced by this
        # XRL target, from everything that has been parsed.
        wanted = set(ref[0] for ref in tgt.interfaces())
        interfaces = dict((i.name(), i) for i in xp.interfaces()
                          if i.name() in wanted)
        cls = "Xrl%sTargetBase" % cpp_classname(tgt.name())
        hh_txt += output_target_hh(cls, tgt, interfaces)
        cc_txt += output_target_cc(cls, tgt, interfaces)

    # The include guard is opened once by prepare_target_hh(), so it is
    # closed once, after every target class has been emitted.
    hh_txt += finish_target_hh(hh_file)

    Xif.util.file_write_string(hh_file, hh_txt)
    Xif.util.file_write_string(cc_file, cc_txt)

if __name__ == '__main__':
    main()
