#! /usr/bin/env python

"""\
%prog [-o outfile] <datafile1> <datafile2>

Compare analysis objects between two YODA-readable data files.
"""

import yoda, sys, optparse

## Parse the command line: exactly two data files, plus optional output
## destination and comparison tolerance
parser = optparse.OptionParser(usage=__doc__)
parser.add_option("-o", "--output", default="-", dest="OUTPUT_FILE",
                  help="write output to the given file (default: stdout)")
parser.add_option("-t", "--tol", type=float, default=1e-5, dest="TOL",
                  help="relative tolerance of numerical difference permitted before complaining (default: %default)")
opts, filenames = parser.parse_args()

if len(filenames) != 2:
    print("ERROR! Please supply *two* YODA files for comparison")
    sys.exit(1)

## Get data objects
aodict1 = yoda.read(filenames[0])
aodict2 = yoda.read(filenames[1])

## Honour -o/--output: every report below is written via print, so rebinding
## sys.stdout sends them all to the requested file. "-" keeps the real stdout.
## This is done only after both inputs were read, so that a failed read does
## not leave an empty output file behind. The file is flushed and closed at
## interpreter exit.
if opts.OUTPUT_FILE != "-":
    sys.stdout = open(opts.OUTPUT_FILE, "w")

## Check number of data objects in each file
if len(aodict1) != len(aodict2):
    print("Different numbers of data objects in %s and %s" % tuple(filenames[:2]))
## Compare the paths as sets: the order in which keys are listed is irrelevant
elif set(aodict1.keys()) != set(aodict2.keys()):
    print("Different data object paths in %s and %s" % tuple(filenames[:2]))

## A slightly tolerant numerical comparison function
def eq(a, b, tol=None):
    """Return True if a and b are equal to within a relative tolerance.

    Arguments:
      a, b -- numbers, or (possibly nested) iterables of numbers.
      tol  -- relative tolerance; if None, the command-line value opts.TOL
              is used.

    Two values count as equal if they compare equal with ==, if both are
    float NaNs, or if |a - b| / (|a| + |b|) < tol. Iterables are equal if
    they have the same length and all corresponding elements are equal.
    Values of different non-numeric types, and unequal strings, are never
    equal.
    """
    from math import isnan
    if a == b:
        return True
    if type(a) is type(b) is float and isnan(a) and isnan(b):
        return True
    ## Type-check: be a bit careful re. int vs. float
    if type(a) is not type(b) and not all(type(x) in (int, float) for x in (a, b)):
        return False
    ## Strings that failed the == test above are simply different; do not try
    ## to iterate over them (infinite recursion) or convert them to float
    if isinstance(a, str):
        return False
    ## Recursively call on pairs of components if a and b are iterables
    if hasattr(a, "__iter__"):
        items_a, items_b = list(a), list(b)
        ## zip() silently truncates, so a length mismatch must be caught here
        if len(items_a) != len(items_b):
            return False
        return all(eq(x, y, tol) for x, y in zip(items_a, items_b))
    ## Finally apply a tolerant numerical comparison on numeric types
    if tol is None:
        tol = opts.TOL
    a, b = float(a), float(b)
    ## Use magnitudes in the denominator: a plain a + b is negative for two
    ## negative values (making everything "equal") and zero for a == -b
    denom = abs(a) + abs(b)
    if denom == 0:
        return True
    return abs(a - b) / denom < tol

def ptstr(pt):
    """Return a short string representation "(..., ...)" of a scatter point.

    One "value +- error" entry is produced per axis, for up to the first
    three axes (x, y, z) as given by pt.dim. If the two errors on an axis
    agree according to eq(), the compact symmetric form is used; otherwise
    both errors are printed separately. All numbers are shown to two
    significant figures.
    """
    # NOTE(review): the asymmetric form labels errs[0] as "+" and errs[1] as
    # "-" -- confirm this matches the ordering of the point's error pairs.
    asym_fmt = "{x:.2g} + {ex[0]:.2g} - {ex[1]:.2g}"
    sym_fmt = "{x:.2g} +- {ex[0]:.2g}"
    parts = []
    for axis_num, axis in enumerate("xyz", 1):
        ## Stop at the first axis beyond the point's dimensionality
        if pt.dim < axis_num:
            break
        value = getattr(pt, axis)
        errs = getattr(pt, axis + "Errs")
        template = sym_fmt if eq(*errs) else asym_fmt
        parts.append(template.format(x=value, ex=errs))
    return "(%s)" % ", ".join(parts)

## Compare each object pair, visiting every path found in either file.
## For each path the checks run from coarse to fine (presence, type, scatter
## dimension, number of points, point values) and stop at the first mismatch.
for path in sorted(set(aodict1.keys()) | set(aodict2.keys())):
    ## Get the object in file #1
    ao1 = aodict1.get(path, None)
    if ao1 is None:
        print("Data object '%s' not found in %s" % (path, filenames[0]))
        continue
    ## Get the object in file #2
    ao2 = aodict2.get(path, None)
    if ao2 is None:
        print("Data object '%s' not found in %s" % (path, filenames[1]))
        continue

    ## Compare the file #1 vs. #2 object types
    if type(ao1) is not type(ao2):
        print("Data objects with path '%s' have different types (%s and %s) in %s and %s" %
              (path, str(type(ao1)), str(type(ao2)), filenames[0], filenames[1]))
        continue

    ## Convert to scatter representations
    try:
        s1 = ao1.mkScatter()
        s2 = ao2.mkScatter()
    except Exception as e:
        print("WARNING! Could not create a '%s' scatter for comparison (%s)" % (path, type(e).__name__))
        ## Nothing to compare for this path. Carrying on would either hit
        ## undefined s1/s2 or silently reuse the scatters of the previous path.
        continue

    ## Check for compatible dimensionalities (should already be ok, but just making sure)
    if s1.dim != s2.dim:
        print("Data objects with path '%s' have different scatter dimensions (%d and %d) in %s and %s" %
              (path, s1.dim, s2.dim, filenames[0], filenames[1]))
        continue

    ## Compare the numbers of points/bins
    if s1.numPoints != s2.numPoints:
        print("Data objects with path '%s' have different numbers of points (%d and %d) in %s and %s" %
              (path, s1.numPoints, s2.numPoints, filenames[0], filenames[1]))
        continue

    ## Compare the numeric values of each point
    premsg = "Data points differ for data objects with path '%s' in %s and %s:\n" % (path, filenames[0], filenames[1])
    msgs = []
    for i, (p1, p2) in enumerate(zip(s1.points, s2.points)):
        # TODO: do this more nicely when point.val(int) and point.err(int) are mapped into Python
        ## Each branch is only reached if all previous axes matched, so the
        ## chain flags the point as soon as any one axis differs
        ok = True
        if p1.dim >= 1 and not (eq(p1.x, p2.x) and eq(p1.xErrs, p2.xErrs)):
            ok = False
        elif p1.dim >= 2 and not (eq(p1.y, p2.y) and eq(p1.yErrs, p2.yErrs)):
            ok = False
        elif p1.dim >= 3 and not (eq(p1.z, p2.z) and eq(p1.zErrs, p2.zErrs)):
            ok = False
        if not ok:
            msgs.append("  Point #%d: %s vs. %s" % (i, ptstr(p1), ptstr(p2)))
    ## Report all differing points of this object in a single message
    if msgs:
        print(premsg + "\n".join(msgs))
