#!/usr/bin/python3

import subprocess
import logging
import tempfile
import os

class SkipTestException(RuntimeError):
   """Raised when a test cannot run because a required argument or file is missing."""

def arg_string(arg, exception=SkipTestException, name=None):
   """Return `arg` unchanged, or raise `exception` if it is empty or None.

   arg       -- the argument value to validate (any falsy value counts as missing)
   exception -- exception class raised when the argument is missing
   name      -- optional human-readable argument name used in the error message
   """
   if not arg:
      # The message used to reference an undefined variable `name`, which
      # turned every missing argument into a NameError instead of `exception`.
      what = "argument %s" % name if name else "required argument"
      raise exception("no %s was given." % what)
   return arg

def arg_file(arg, fmt_str="%s", exception=SkipTestException, name=None):
   """Build a file name from `arg` and return it if that file exists.

   arg       -- the argument value to validate (any falsy value counts as missing)
   fmt_str   -- %-format string that turns `arg` into the file name
   exception -- exception class raised when the argument or the file is missing
   name      -- optional human-readable argument name used in the error message
   """
   if not arg:
      # The message used to reference an undefined variable `name`, which
      # turned every missing argument into a NameError instead of `exception`.
      what = "argument %s" % name if name else "required argument"
      raise exception("no %s was given." % what)
   fn = fmt_str % arg
   if not os.path.exists(fn):
      raise exception("no file %s was found (derived from %s)." % (fn, arg))
   return fn

def check_call(cmd, **kwargs):
   """Log the command line at debug level, then run it with subprocess.check_call.

   Extra keyword arguments are forwarded to subprocess.check_call. Nothing is
   returned; a non-zero exit status raises subprocess.CalledProcessError.
   """
   quoted = "' '".join(cmd)
   logging.debug("Executing: '%s'", quoted)
   subprocess.check_call(cmd, **kwargs)

def check_output(cmd):
   """Log the command line at debug level, run it, and return its decoded stdout."""
   quoted = "' '".join(cmd)
   logging.debug("Executing: '%s'", quoted)
   raw = subprocess.check_output(cmd)
   return raw.decode()


def dump_trace():
   """Run ``dump-trace --stats`` on the benchmark trace and compare with the recorded stats.

   Reads the module-level `args` for the dump-trace binary and the benchmark
   directory. Returns -1 (after logging and printing both outputs) when the
   fresh output differs from <benchmark>/trace.dump-trace-stats, otherwise None.
   """
   binary        = arg_string(args.dump_trace)
   trace_file    = arg_file(args.benchmark, "%s/trace.pb")
   expected_file = arg_file(args.benchmark, "%s/trace.dump-trace-stats")

   with open(expected_file) as fd:
      expected = fd.read()

   cmd = [binary, '--stats', trace_file]
   logging.info("Executing: '%s'", "' '".join(cmd))
   actual = subprocess.check_output(cmd).decode()

   if actual == expected:
      return None

   logging.error("Mismatch with the prerecorded trace stats file:")
   print(expected, actual)
   return -1


# Database related
def mysql(statement):
   """Run `statement` through the mysql command line client and yield result rows.

   The client is pointed at the configuration file from the module-level
   `args.my_cnf`. The first line of the tab-separated output is taken as the
   column header; every following line is yielded as a dict mapping column
   name to the (string) field value. This is a generator, so the statement is
   only executed once iteration starts.
   """
   config = arg_file(args.my_cnf)
   output = check_output(["mysql", "--defaults-file=" + config, "-e", statement])

   # str.split always yields at least one element, so there is always a
   # header row to consume (it is [""] for empty output).
   rows = (line.split("\t") for line in output.strip().split("\n"))
   columns = next(rows)
   for values in rows:
      yield dict(zip(columns, values))

def import_trace():
   """Import the benchmark trace into the database and cross-check the result.

   Runs the import-trace binary once for the "memory" and once for the
   "register" fault space, then compares the imported `trace` rows with the
   values recorded in <benchmark>/trace.dump-trace-stats. Reads the
   module-level `args`; raises AssertionError when a check fails.
   """
   global args
   import_trace   = arg_string(args.import_trace)
   trace_pb       = arg_file(args.benchmark, "%s/trace.pb")
   my_cnf         = arg_file(args.my_cnf)
   trace_pb_stats = arg_file(args.benchmark, "%s/trace.dump-trace-stats")
   # Parse the recorded stats file into a dict: each line is split on ":" and
   # the first two fields become key and value (both whitespace-stripped).
   with open(trace_pb_stats) as fd:
      trace_pb_stats = [x.split(":") for x in fd.read().strip().split("\n")]
      trace_pb_stats = {x[0].strip(): x[1].strip() for x in trace_pb_stats}

   # mysql() is a generator; wrapping it in list() forces the statement to run.
   list(mysql("DROP TABLE IF EXISTS variant"))
   list(mysql("DROP TABLE IF EXISTS trace"))

   
   # Each fault space is imported as its own "benchmark" of the same variant.
   for location, extra_argv in [
         ("memory",   ["-i", "mem"]),
         ("register", ["-i", "regs", "--flags"])]:
      cmd = [import_trace,
             "--database-option-file", my_cnf,
             "-t", trace_pb,
             "-e", arg_file(args.benchmark, "%s/system.elf"),
             "-v", args.benchmark,
             "-b", location,
             ] + extra_argv
      check_call(cmd)

      # Fetch the instruction and time range covered by the imported trace.
      result = mysql(f"""SELECT v.id as variant_id,
                                min(instr1) as min_instr, 
                                max(instr2) as max_instr, 
                                min(time1) as min_time, 
                                max(time2) as max_time
                        FROM trace t
                        JOIN variant v ON v.id = t.variant_id 
                        WHERE v.variant = "{args.benchmark}" and v.benchmark = "{location}" 
                      """)
      trace_stats = next(result)

      # Check that the number of imported instructions equals the
      # "#instructions" value reported by dump-trace --stats
      trace_instrs = int(trace_stats["max_instr"]) - int(trace_stats['min_instr']) + 1
      stats_instrs = int(trace_pb_stats["#instructions"])
      assert trace_instrs == stats_instrs,\
         f"Number of instructions differs for {args.benchmark}/{location}"

      # Same check for the covered time span against the recorded "duration".
      trace_time = int(trace_stats["max_time"]) - int(trace_stats['min_time']) + 1
      stats_time = int(trace_pb_stats["duration"])
      assert trace_time == stats_time,\
         f"Lenght of imported trace differs for {args.benchmark}/{location}"


      # Infer the number of fault locations: every (address + offset, bit)
      # pair touched by a trace entry, with 8 bits per byte of `width`.
      result = mysql(f"""SELECT DISTINCT data_address, width
                        FROM trace t
                        JOIN variant v ON v.id = t.variant_id 
                        WHERE v.variant = "{args.benchmark}" and v.benchmark = "{location}" 
                      """)
      bits = set()
      for r in result:
         for offset in range(0, int(r['width'])):
            for bit in range(0, 8):
               bits.add((int(r['data_address']) + offset, bit))

      logging.info(f"{len(bits)} fault locations ({location})")


      # Only the memory fault space has a recorded location count to compare to.
      if location == "memory":
         assert len(bits) == int(trace_pb_stats["#memLocations"]) * 8,\
            "Number of fault locations (memory) is not equal to dump-trace"

      # The summed area of all trace intervals (duration * width * 8 bits) must
      # equal the full rectangle of fault locations times trace duration.
      result = mysql(f"""SELECT sum((t.time2 - t.time1 + 1) * width * 8) as fault_space_size
                        FROM trace t
                        JOIN variant v ON v.id = t.variant_id 
                        WHERE v.variant = "{args.benchmark}" and v.benchmark = "{location}" 
                      """)
      fault_space_size = int(next(result)['fault_space_size'])
      assert fault_space_size == len(bits) * trace_time,\
         f"Equivalence Sets do not add up to a square fault space (location={location})"


def basic_pruner():
   """Run prune-trace with the BasicPruner on every variant and verify the result.

   Reads the module-level `args` for the prune-trace binary and the MySQL
   configuration file. The `trace`, `variant` and `injection` tables are
   expected to exist already. For every row of `variant` it checks that

   * every trace interval is covered by an fsppilot, and
   * the injection result of each interval matches the result of its pilot.

   Raises AssertionError when one of the checks fails.
   """
   # Import Trace has already happened
   prune_trace   = arg_string(args.prune_trace)
   my_cnf         = arg_file(args.my_cnf)

   # mysql() is a generator; wrapping it in list() forces the statement to run.
   list(mysql("DROP TABLE IF EXISTS fspgroup"))
   list(mysql("DROP TABLE IF EXISTS fsppilot"))

   for variant in mysql("select * from variant"):
      cmd = [prune_trace, "--database-option-file", my_cnf,
             "-v", variant['variant'], "-b", variant['benchmark'],
             "-p", "BasicPruner"
             ]

      check_output(cmd)

      # Is every trace event covered by a fsppilot?
      sql = f"""SELECT p.id is null as no_pilot, count(*) as intervals, sum((t.time2 - t.time1 + 1) * t.width * 8) as area,
                       count(distinct p.id) as pilots
         FROM trace t
         LEFT OUTER JOIN fspgroup g
              on  t.variant_id = g.variant_id 
              and t.instr2 = g.instr2
              and t.data_address = g.data_address
         LEFT OUTER JOIN fsppilot p
              on  t.variant_id = p.variant_id
              and g.fspmethod_id = p.fspmethod_id
              and g.pilot_id = p.id
         WHERE t.variant_id = {variant['id']}
         GROUP by p.id is null
         """
      for result in mysql(sql):
         # mysql() yields every field as a string, hence the comparison to '0'.
         if result['no_pilot'] != '0':
            assert False, "Found Equivalence intervals without pilots"
         else:
            logging.info(f"{variant['benchmark']}: {result['pilots']} fsppilots, {result['intervals']} intervals, {result['area']}")


      # Check if pruning results match the injection table
      sql = f"""SELECT t.instr2, t.data_address, i1.bitoffset, i1.resulttype as `should`, i2.resulttype as `is`
          FROM trace t
          JOIN injection i1 on t.instr2 = i1.instr2 and t.data_address = i1.data_address
          JOIN fspgroup g
              on  t.variant_id = g.variant_id
              and t.instr2 = g.instr2
              and t.data_address = g.data_address
          JOIN fsppilot p
              on  t.variant_id = p.variant_id
              and g.fspmethod_id = p.fspmethod_id
              and g.pilot_id = p.id
          JOIN injection i2
              on p.instr2 = i2.instr2 and p.data_address = i2.data_address
         WHERE t.variant_id = {variant['id']} and i1.bitoffset = i2.bitoffset and i1.resulttype != i2.resulttype"""

      differ = False
      for row in mysql(sql):
         # This used the undefined name `logger` (a NameError exactly when a
         # mismatch was found) and logged the bound method `row.values`
         # instead of the actual values.
         logging.info("mismatch: %s", list(row.values()))
         differ = True

      assert not differ, "There was a result mismatch for the basic pruner"
      

def import_injection():
   """Feed <benchmark>/injection.sql into the mysql command line client.

   Reads the module-level `args` for the MySQL configuration file and the
   benchmark directory; the SQL file is passed to mysql on stdin.
   """
   config   = arg_file(args.my_cnf)
   sql_file = arg_file(args.benchmark, "%s/injection.sql")
   logging.info("Import injection results")
   cmd = ["mysql", "--defaults-file=" + config]
   with open(sql_file, mode="rb", buffering=0) as fd:
      check_call(cmd, stdin=fd)


if __name__ == "__main__":
   import argparse
   import sys

   # Maps each test mode to the list of steps that are executed in order.
   tests = {
      'dump-trace':      [dump_trace],
      'basic-pruner': [import_trace, import_injection, basic_pruner],
   }

   parser = argparse.ArgumentParser("FAIL* Test Driver")

   parser.add_argument("--dump-trace", help="dump-trace binary")
   parser.add_argument("--import-trace", help="import-trace binary")
   parser.add_argument("--prune-trace", help="prune-trace binary")

   parser.add_argument("--my-cnf", help="MYSQL configuration file")

   parser.add_argument("-v", "--verbose", action="store_true", default=False,
                       help="be verbose")

   parser.add_argument("test_mode", help="Which test to execute?",
                       choices=tests.keys())

   parser.add_argument("benchmark", help="Benchmark")

   # `args` is a module-level name that the test functions above read; a
   # `global` statement is not needed (and has no effect) at module scope.
   args = parser.parse_args()

   if args.verbose:
      logging.basicConfig(level=logging.DEBUG)
   else:
      logging.basicConfig(level=logging.INFO)


   # Dispatch to a given test mode
   try:
      for func in tests[args.test_mode]:
         # A step may report failure through a non-zero return value
         # (dump_trace returns -1 on a mismatch). Previously that value was
         # discarded and the driver exited with status 0.
         if func():
            logging.error("Test step %s failed", func.__name__)
            sys.exit(1)
   except SkipTestException as e:
      # Exit status 127 marks the test as skipped rather than failed.
      print("Skipping Test, because:", e)
      sys.exit(127)

