#!/usr/bin/python3
"""Test driver that runs the dump-trace / import-trace / prune-trace tools.

Each test mode is a list of step functions executed in order.  A step raises
``SkipTestException`` when a tool or input file it needs was not supplied;
the driver then exits with status 127.
"""
import argparse
import logging
import os
import subprocess
import sys
import tempfile

# Parsed command line (argparse.Namespace).  Populated by main(); the step
# functions below read it as a module-level global.
args = None


class SkipTestException(RuntimeError):
    """Raised when a test cannot run because an argument or file is missing."""


def _missing_argument_message(name):
    """Build the 'argument missing' message, with or without a name."""
    if name:
        return "no argument %s was given." % name
    return "no argument was given."


def arg_string(arg, exception=SkipTestException, name=None):
    """Return *arg* if it is truthy, otherwise raise *exception*.

    *name* is only used to make the error message more helpful.
    """
    if not arg:
        raise exception(_missing_argument_message(name))
    return arg


def arg_file(arg, fmt_str="%s", exception=SkipTestException, name=None):
    """Return the path ``fmt_str % arg`` if that path exists.

    Raises *exception* when *arg* is falsy or when the derived path does not
    exist.  *name* is only used to make the error message more helpful.
    """
    if not arg:
        raise exception(_missing_argument_message(name))
    fn = fmt_str % arg
    if not os.path.exists(fn):
        raise exception("no file %s was found (dervied from %s)." % (fn, arg))
    return fn


def check_call(cmd, **kwargs):
    """Log *cmd* at debug level and run it via ``subprocess.check_call``."""
    logging.debug("Executing: '%s'", "' '".join(cmd))
    subprocess.check_call(cmd, **kwargs)


def check_output(cmd):
    """Log *cmd* at debug level, run it and return its decoded stdout."""
    logging.debug("Executing: '%s'", "' '".join(cmd))
    return subprocess.check_output(cmd).decode()


def _read_stats(path):
    """Parse a ``key: value`` per line stats file into a dict of strings.

    Lines without a ``:`` separator (for example blank lines) are skipped.
    As in the original parser, only the text between the first and second
    colon is used as the value.
    """
    stats = {}
    with open(path) as fd:
        for line in fd.read().strip().split("\n"):
            parts = line.split(":")
            if len(parts) < 2:
                continue
            stats[parts[0].strip()] = parts[1].strip()
    return stats


def dump_trace():
    """Run ``dump-trace --stats`` and compare it to a pre-recorded output.

    The expected output is read from ``<benchmark>/trace.dump-trace-stats``.
    Returns -1 on a mismatch and None otherwise.
    """
    dump_trace_bin = arg_string(args.dump_trace, name="--dump-trace")
    trace_pb = arg_file(args.benchmark, "%s/trace.pb", name="benchmark")
    stats_file = arg_file(args.benchmark, "%s/trace.dump-trace-stats",
                          name="benchmark")
    with open(stats_file) as fd:
        expected = fd.read()

    cmd = [dump_trace_bin, '--stats', trace_pb]
    logging.info("Executing: '%s'", "' '".join(cmd))
    stats = subprocess.check_output(cmd).decode()
    if stats != expected:
        logging.error("Mismatch with the prerecorded trace stats file:")
        print(expected, stats)
        return -1


# Database related
def mysql(statement):
    """Run *statement* through the ``mysql`` client and yield result rows.

    The client output is split into lines and tab-separated fields; the first
    line is taken as the header and every following line is yielded as a dict
    mapping header names to string values.  This is a generator: nothing is
    executed until it is iterated, which is why callers that only want the
    side effect wrap the call in ``list(...)``.
    """
    my_cnf = arg_file(args.my_cnf, name="--my-cnf")
    cmd = ["mysql", "--defaults-file=" + my_cnf, "-e", statement]
    output = check_output(cmd)
    header = None
    for line in output.strip().split("\n"):
        fields = line.split("\t")
        if not header:
            header = fields
        else:
            yield dict(zip(header, fields))


def import_trace():
    """Import the benchmark trace into the database and sanity-check it.

    For the "memory" and "register" locations the trace is imported with the
    import-trace tool and the imported rows are compared against the values
    recorded in ``<benchmark>/trace.dump-trace-stats``.
    """
    import_trace_bin = arg_string(args.import_trace, name="--import-trace")
    trace_pb = arg_file(args.benchmark, "%s/trace.pb", name="benchmark")
    my_cnf = arg_file(args.my_cnf, name="--my-cnf")
    trace_pb_stats = _read_stats(
        arg_file(args.benchmark, "%s/trace.dump-trace-stats",
                 name="benchmark"))

    list(mysql("DROP TABLE IF EXISTS variant"))
    list(mysql("DROP TABLE IF EXISTS trace"))

    locations = [
        ("memory", ["-i", "mem"]),
        ("register", ["-i", "regs", "--flags"]),
    ]
    for location, extra_argv in locations:
        cmd = [
            import_trace_bin,
            "--database-option-file", my_cnf,
            "-t", trace_pb,
            "-e", arg_file(args.benchmark, "%s/system.elf", name="benchmark"),
            "-v", args.benchmark,
            "-b", location,
        ] + extra_argv
        check_call(cmd)

        result = mysql(f"""SELECT v.id as variant_id,
                   min(instr1) as min_instr, max(instr2) as max_instr,
                   min(time1) as min_time, max(time2) as max_time
            FROM trace t
            JOIN variant v ON v.id = t.variant_id
            WHERE v.variant = "{args.benchmark}" and v.benchmark = "{location}"
        """)
        trace_stats = next(result)

        # Check if the number of imported instructions is equal to the result
        # of dump-trace --stats
        trace_instrs = (int(trace_stats["max_instr"])
                        - int(trace_stats['min_instr']) + 1)
        stats_instrs = int(trace_pb_stats["#instructions"])
        assert trace_instrs == stats_instrs,\
            f"Number of instructions differs for {args.benchmark}/{location}"

        trace_time = (int(trace_stats["max_time"])
                      - int(trace_stats['min_time']) + 1)
        stats_time = int(trace_pb_stats["duration"])
        assert trace_time == stats_time,\
            f"Lenght of imported trace differs for {args.benchmark}/{location}"

        # Infer the number of fault locations
        result = mysql(f"""SELECT DISTINCT data_address, width
            FROM trace t
            JOIN variant v ON v.id = t.variant_id
            WHERE v.variant = "{args.benchmark}" and v.benchmark = "{location}"
        """)
        bits = set()
        for r in result:
            base = int(r['data_address'])
            for offset in range(int(r['width'])):
                for bit in range(8):
                    bits.add((base + offset, bit))
        logging.info("%d fault locations (%s)", len(bits), location)
        if location == "memory":
            assert len(bits) == int(trace_pb_stats["#memLocations"]) * 8,\
                "Number of fault locations (memory) is not equal to dump-trace"

        result = mysql(f"""SELECT
                   sum((t.time2 - t.time1 + 1) * width * 8) as fault_space_size
            FROM trace t
            JOIN variant v ON v.id = t.variant_id
            WHERE v.variant = "{args.benchmark}" and v.benchmark = "{location}"
        """)
        fault_space_size = int(next(result)['fault_space_size'])
        assert fault_space_size == len(bits) * trace_time,\
            f"Equivalence Sets do not add up to a square fault space (location={location})"


def basic_pruner():
    """Run the BasicPruner on every variant and validate its results.

    Expects the trace import to have already happened.
    """
    prune_trace = arg_string(args.prune_trace, name="--prune-trace")
    my_cnf = arg_file(args.my_cnf, name="--my-cnf")

    list(mysql("DROP TABLE IF EXISTS fspgroup"))
    list(mysql("DROP TABLE IF EXISTS fsppilot"))

    for variant in mysql("select * from variant"):
        cmd = [
            prune_trace,
            "--database-option-file", my_cnf,
            "-v", variant['variant'],
            "-b", variant['benchmark'],
            "-p", "BasicPruner",
        ]
        check_output(cmd)

        # Is every trace event covered by a fsppilot?
        sql = f"""SELECT p.id is null as no_pilot, count(*) as intervals,
                   sum((t.time2 - t.time1 + 1) * t.width * 8) as area,
                   count(distinct p.id) as pilots
            FROM trace t
            LEFT OUTER JOIN fspgroup g
              on t.variant_id = g.variant_id and t.instr2 = g.instr2
             and t.data_address = g.data_address
            LEFT OUTER JOIN fsppilot p
              on t.variant_id = p.variant_id
             and g.fspmethod_id = p.fspmethod_id and g.pilot_id = p.id
            WHERE t.variant_id = {variant['id']}
            GROUP by p.id is null
        """
        for result in mysql(sql):
            if result['no_pilot'] != '0':
                assert False, "Found Equivalence intervals without pilots"
            else:
                logging.info("%s: %s fsppilots, %s intervals, %s",
                             variant['benchmark'], result['pilots'],
                             result['intervals'], result['area'])

        # Check if pruning results match the injection table
        sql = f"""SELECT t.instr2, t.data_address, i1.bitoffset,
                   i1.resulttype as `should`, i2.resulttype as `is`
            FROM trace t
            JOIN injection i1
              on t.instr2 = i1.instr2 and t.data_address = i1.data_address
            JOIN fspgroup g
              on t.variant_id = g.variant_id and t.instr2 = g.instr2
             and t.data_address = g.data_address
            JOIN fsppilot p
              on t.variant_id = p.variant_id
             and g.fspmethod_id = p.fspmethod_id and g.pilot_id = p.id
            JOIN injection i2
              on p.instr2 = i2.instr2 and p.data_address = i2.data_address
            WHERE t.variant_id = {variant['id']}
              and i1.bitoffset = i2.bitoffset
              and i1.resulttype != i2.resulttype"""
        differ = False
        for row in mysql(sql):
            # Log every mismatching row before failing, so all are visible.
            logging.error("mismatch: %s", list(row.values()))
            differ = True
        assert not differ, "There was a result mismatch for the basic pruner"


def import_injection():
    """Feed ``<benchmark>/injection.sql`` into the ``mysql`` client."""
    my_cnf = arg_file(args.my_cnf, name="--my-cnf")
    injection = arg_file(args.benchmark, "%s/injection.sql", name="benchmark")
    logging.info("Import injection results")
    # Unbuffered binary handle: the file is handed to the child as stdin.
    with open(injection, "rb", 0) as fd:
        check_call(["mysql", "--defaults-file=" + my_cnf], stdin=fd)


def main():
    """Parse the command line and run the steps of the selected test mode.

    Exits with 127 when a step raises SkipTestException and with 1 when a
    step reports failure through a truthy return value.
    """
    global args

    tests = {
        'dump-trace': [dump_trace],
        'basic-pruner': [import_trace, import_injection, basic_pruner],
    }

    parser = argparse.ArgumentParser("FAIL* Test Driver")
    parser.add_argument("--dump-trace", help="dump-trace binary")
    parser.add_argument("--import-trace", help="import-trace binary")
    parser.add_argument("--prune-trace", help="prune-trace binary")
    parser.add_argument("--my-cnf", help="MYSQL configuration file")
    parser.add_argument("-v", "--verbose", action="store_true", default=False,
                        help="be verbose")
    parser.add_argument("test_mode", help="Which test to execute?",
                        choices=tests.keys())
    parser.add_argument("benchmark", help="Benchmark")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    # Dispatch to a given test mode
    try:
        for func in tests[args.test_mode]:
            # A step signals failure with a truthy return value (dump_trace
            # returns -1 on mismatch); propagate it as a failing exit status.
            if func():
                sys.exit(1)
    except SkipTestException as e:
        print("Skipping Test, because:", e)
        sys.exit(127)


if __name__ == "__main__":
    main()