diff --git a/src/core/util/SumTree.hpp b/src/core/util/SumTree.hpp index 998e3e05..d5df54da 100644 --- a/src/core/util/SumTree.hpp +++ b/src/core/util/SumTree.hpp @@ -4,10 +4,11 @@ #include #include #include +#include // The SumTree implements an efficient tree data structure for // "roulette-wheel" sampling, or "sampling with fault expansion", i.e., -// sampling of trace entries / pilots without replacement and with a +// sampling of trace entries / pilots with/without replacement and with a // picking probability proportional to the entries' sizes. // // For every sample, the naive approach picks a random number between 0 @@ -24,9 +25,9 @@ // // Note that the current implementation is built for a pure growth phase // (when the tree gets filled with pilots from the database), followed by -// a sampling phase when the tree gets emptied. It does not handle a -// mixed add/remove case very smartly, although it should remain -// functional. +// a sampling phase when the tree gets sampled from (with replacement) or +// emptied (without replacement). It does not handle a mixed add/remove case +// very smartly, although it should remain functional. namespace fail { @@ -44,6 +45,64 @@ class SumTree { std::vector elements; }; +public: + //! Iterator + class TreeIterator : public std::iterator { + //! Buckets and corresponding element indexes down the tree + std::stack > hierarchy; + public: + TreeIterator() {} +//MyIterator(int* x) :p(x) {} + TreeIterator(const TreeIterator& i) : hierarchy(i.hierarchy) { } + TreeIterator(const SumTree& tree) + { + // go down until we see leaves + hierarchy.push(std::pair(tree.m_root, 0)); + while (!hierarchy.top().first->elements.size() && hierarchy.top().first->children.size() > 0) { + hierarchy.push(std::pair(hierarchy.top().first->children[hierarchy.top().second], 0)); + } + } + TreeIterator& operator++() + { + // advance index in the current level + hierarchy.top().second++; + if (hierarchy.top().second < hierarchy.top().first->elements.size()) { + return *this; + } + // current level is exhausted, go back up to a not yet finished level + do { + hierarchy.pop(); + } while (!hierarchy.empty() + && ++hierarchy.top().second >= hierarchy.top().first->children.size()); + // at the end? + if (hierarchy.empty()) { + return *this; + } + // go down until we see leaves again + do { + hierarchy.push(std::pair(hierarchy.top().first->children[hierarchy.top().second], 0)); + } while (!hierarchy.top().first->elements.size() && hierarchy.top().first->children.size() > 0); + return *this; + } + TreeIterator operator++(int) { TreeIterator tmp(*this); operator++(); return tmp; } + bool operator==(const TreeIterator& rhs) { return hierarchy == rhs.hierarchy; } + bool operator!=(const TreeIterator& rhs) { return hierarchy != rhs.hierarchy; } + T& operator*() { return hierarchy.top().first->elements[hierarchy.top().second]; } + T *operator->() { return &(operator*()); } + }; + typedef TreeIterator iterator; + + iterator begin() + { + return iterator(*this); + } + + iterator end() + { + return iterator(); + } + +private: //! Root node Bucket *m_root; //! Tree depth: nodes at level m_depth are leaf nodes, others are inner nodes @@ -51,17 +110,21 @@ class SumTree { public: SumTree() : m_root(new Bucket), m_depth(0) {} ~SumTree() { delete m_root; } - //! Adds a new element to the tree. + //! Adds a copy of a new element to the tree. The copy is created internally. void add(const T& element); - //! Retrieves (and removes) element at random number position. - T get(typename T::size_type pos) { return get(pos, m_root, 0); } + //! Retrieves and removes element at random number position. + T remove(typename T::size_type pos) { return remove(pos, m_root, 0); } + //! Retrieves reference to element at random number position. + T& get(typename T::size_type pos) { return get(pos, m_root, 0); } //! Yields the sum over all elements in the tree. typename T::size_type get_size() const { return m_root->size; } private: //! Internal, recursive version of add(). bool add(Bucket **node, const T& element, unsigned depth_remaining); + //! Internal, recursive version of remove(). + T remove(typename T::size_type pos, Bucket *node, typename T::size_type sum); //! Internal, recursive version of get(). - T get(typename T::size_type pos, Bucket *node, typename T::size_type sum); + T& get(typename T::size_type pos, Bucket *node, typename T::size_type sum); }; // template implementation @@ -137,7 +200,7 @@ bool SumTree::add(Bucket **node, const T& element, unsigned depth } template -T SumTree::get(typename T::size_type pos, Bucket *node, typename T::size_type sum) +T SumTree::remove(typename T::size_type pos, Bucket *node, typename T::size_type sum) { // sanity check assert(pos >= sum && pos < sum + node->size); @@ -153,7 +216,7 @@ T SumTree::get(typename T::size_type pos, Bucket *node, typename // found containing bucket, recurse sum -= (*it)->size; - T e = get(pos, *it, sum); + T e = remove(pos, *it, sum); node->size -= e.size(); // remove empty (or, at least, zero-sized) child? if ((*it)->size == 0) { @@ -184,6 +247,44 @@ T SumTree::get(typename T::size_type pos, Bucket *node, typename return T(); } +template +T& SumTree::get(typename T::size_type pos, Bucket *node, typename T::size_type sum) +{ + // sanity check + assert(pos >= sum && pos < sum + node->size); + + // will only be entered for inner nodes + for (typename std::vector::iterator it = node->children.begin(); + it != node->children.end(); ) { + sum += (*it)->size; + if (sum <= pos) { + ++it; + continue; + } + + // found containing bucket, recurse + sum -= (*it)->size; + return get(pos, *it, sum); + } + + // will only be entered for leaf nodes + for (typename std::vector::iterator it = node->elements.begin(); + it != node->elements.end(); ) { + sum += it->size(); + if (sum <= pos) { + ++it; + continue; + } + + // found pilot + return *it; + } + + // this should never happen + assert(0); + return *(new T); +} + } // namespace #endif diff --git a/src/core/util/testing/SumTreeTest.cc b/src/core/util/testing/SumTreeTest.cc index 61153435..e4880632 100644 --- a/src/core/util/testing/SumTreeTest.cc +++ b/src/core/util/testing/SumTreeTest.cc @@ -19,18 +19,24 @@ struct Pilot { int main() { - fail::SumTree tree; + typedef fail::SumTree sumtree_type; + sumtree_type tree; for (int i = 0; i <= 20; ++i) { Pilot p; p.duration = i; tree.add(p); } + LOG << "tree contents:" << endl; + for (sumtree_type::iterator it = tree.begin(); it != tree.end(); ++it) { + LOG << it->size() << endl; + } + while (tree.get_size() > 0) { uint64_t pos = tree.get_size() / 2; LOG << "MAIN tree.get_size() = " << tree.get_size() << ", trying to retrieve pos = " << pos << endl; - Pilot p = tree.get(pos); + Pilot p = tree.remove(pos); LOG << "MAIN retrieved pilot with duration " << p.duration << endl; } } diff --git a/tools/prune-trace/CMakeLists.txt b/tools/prune-trace/CMakeLists.txt index 4043ffd4..aadf6545 100644 --- a/tools/prune-trace/CMakeLists.txt +++ b/tools/prune-trace/CMakeLists.txt @@ -2,6 +2,7 @@ set(SRCS Pruner.cc BasicPruner.cc FESamplingPruner.cc + SamplingPruner.cc ) find_package(MySQL REQUIRED) diff --git a/tools/prune-trace/FESamplingPruner.cc b/tools/prune-trace/FESamplingPruner.cc index fb4f925b..e8f0710f 100644 --- a/tools/prune-trace/FESamplingPruner.cc +++ b/tools/prune-trace/FESamplingPruner.cc @@ -14,7 +14,10 @@ struct Pilot { uint64_t duration; uint32_t instr2; + union { uint32_t instr2_absolute; + uint32_t id; + }; uint32_t data_address; typedef uint64_t size_type; @@ -26,6 +29,12 @@ bool FESamplingPruner::commandline_init() fail::CommandLine &cmd = fail::CommandLine::Inst(); SAMPLESIZE = cmd.addOption("", "samplesize", Arg::Required, "--samplesize N \tNumber of samples to take (per variant)"); + USE_KNOWN_RESULTS = cmd.addOption("", "use-known-results", Arg::None, + "--use-known-results \tReuse known results from a campaign with the 'basic' pruner " + "(abuses the DB layout to a certain degree, use with caution)"); + NO_WEIGHTING = cmd.addOption("", "no-weighting", Arg::None, + "--no-weighting \tDisable weighted sampling (weight = 1 for all ECs) " + "(don't do this unless you know what you're doing)"); return true; } @@ -38,6 +47,14 @@ bool FESamplingPruner::prune_all() } m_samplesize = strtoul(cmd[SAMPLESIZE].first()->arg, 0, 10); + if (cmd[USE_KNOWN_RESULTS]) { + m_use_known_results = true; + } + + if (cmd[NO_WEIGHTING]) { + m_weighting = false; + } + // for each variant: for (std::vector::const_iterator it = m_variants.begin(); it != m_variants.end(); ++it) { @@ -76,78 +93,129 @@ bool FESamplingPruner::sampling_prune(const fail::Database::Variant& variant) MYSQL_RES *res; MYSQL_ROW row; - LOG << "loading trace entries for " << variant.variant << "/" << variant.benchmark << " ..." << endl; + uint64_t pilotcount = 0, samplerows; - unsigned pilotcount = 0; + if (!m_use_known_results) { + LOG << "loading trace entries for " << variant.variant << "/" << variant.benchmark << " ..." << endl; - // load trace entries - ss << "SELECT instr2, instr2_absolute, data_address, time2-time1+1 AS duration" - << " FROM trace" - << " WHERE variant_id = " << variant.id - << " AND accesstype = 'R'" - << " ORDER BY duration DESC"; // speeds up sampling, but query may be slow - res = db->query_stream(ss.str().c_str()); - ss.str(""); - if (!res) return false; - while ((row = mysql_fetch_row(res))) { - Pilot p; - p.instr2 = strtoul(row[0], 0, 10); - p.instr2_absolute = strtoul(row[1], 0, 10); - p.data_address = strtoul(row[2], 0, 10); - p.duration = strtoull(row[3], 0, 10); - pop.add(p); - ++pilotcount; + // load trace entries + ss << "SELECT instr2, instr2_absolute, data_address, time2-time1+1 AS duration" + << " FROM trace" + << " WHERE variant_id = " << variant.id + << " AND accesstype = 'R'"; + res = db->query_stream(ss.str().c_str()); + ss.str(""); + if (!res) return false; + while ((row = mysql_fetch_row(res))) { + Pilot p; + p.instr2 = strtoul(row[0], 0, 10); + p.instr2_absolute = strtoul(row[1], 0, 10); + p.data_address = strtoul(row[2], 0, 10); + p.duration = m_weighting ? strtoull(row[3], 0, 10) : 1; + pop.add(p); + ++pilotcount; + } + mysql_free_result(res); + + samplerows = std::min(pilotcount, m_samplesize); + } else { + LOG << "loading pilots for " << variant.variant << "/" << variant.benchmark << " ..." << endl; + + // load fsppilot entries + ss << "SELECT p.id, p.instr2, p.data_address, t.time2 - t.time1 + 1 AS duration" + << " FROM fsppilot p" + << " JOIN trace t" + << " ON t.variant_id = p.variant_id AND t.data_address = p.data_address AND t.instr2 = p.instr2" + << " WHERE p.fspmethod_id = " << db->get_fspmethod_id("basic") + << " AND p.variant_id = " << variant.id + << " AND p.known_outcome = 0"; + res = db->query_stream(ss.str().c_str()); + ss.str(""); + if (!res) return false; + while ((row = mysql_fetch_row(res))) { + Pilot p; + p.id = strtoul(row[0], 0, 10); + p.instr2 = strtoul(row[1], 0, 10); + p.data_address = strtoul(row[2], 0, 10); + p.duration = m_weighting ? strtoull(row[3], 0, 10) : 1; + pop.add(p); + ++pilotcount; + } + mysql_free_result(res); + + samplerows = std::min(pilotcount, m_samplesize); } - mysql_free_result(res); - - unsigned samplerows = std::min(pilotcount, m_samplesize); LOG << "loaded " << pilotcount << " entries, sampling " << samplerows << " entries with fault expansion ..." << endl; - // FIXME: change strategy when trace entries have IDs, insert into fspgroup first - ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, " - << "injection_instr_absolute, data_address, data_width, fspmethod_id) VALUES "; - std::string insert_sql(ss.str()); - ss.str(""); + uint64_t num_fspgroup_entries = 0; + uint32_t known_pilot_method_id = m_method_id; - for (unsigned i = 0; i < samplerows; ++i) { - uint64_t pos = my_rand(pop.get_size() - 1); - Pilot p = pop.get(pos); - ss << "(0," << variant.id << "," << p.instr2 << "," << p.instr2 - << "," << p.instr2_absolute << "," << p.data_address - << ",1," << m_method_id << ")"; - db->insert_multiple(insert_sql.c_str(), ss.str().c_str()); + if (!m_use_known_results) { + // FIXME: change strategy when trace entries have IDs, insert into fspgroup first + ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, " + << "injection_instr_absolute, data_address, data_width, fspmethod_id) VALUES "; + std::string insert_sql(ss.str()); ss.str(""); + + for (uint64_t i = 0; i < samplerows; ++i) { + uint64_t pos = my_rand(pop.get_size() - 1); + Pilot p = pop.remove(pos); + ss << "(0," << variant.id << "," << p.instr2 << "," << p.instr2 + << "," << p.instr2_absolute << "," << p.data_address + << ",1," << m_method_id << ")"; + db->insert_multiple(insert_sql.c_str(), ss.str().c_str()); + ss.str(""); + } + db->insert_multiple(); + + uint64_t num_fsppilot_entries = samplerows; + + // single entry for known outcome (write access) + ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, injection_instr_absolute, data_address, data_width, fspmethod_id) " + "SELECT 1, variant_id, instr2, instr2, instr2_absolute, " + " data_address, width, " << m_method_id << " " + "FROM trace " + "WHERE variant_id = " << variant.id << " AND accesstype = 'W' " + "ORDER BY instr2 ASC " + "LIMIT 1"; + if (!db->query(ss.str().c_str())) return false; + ss.str(""); + num_fsppilot_entries += db->affected_rows(); + + LOG << "created " << num_fsppilot_entries << " fsppilot entries" << std::endl; + + // fspgroup entries for sampled trace entries + ss << "INSERT INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id) " + << "SELECT p.variant_id, p.instr2, p.data_address, p.fspmethod_id, p.id " + << "FROM fsppilot p " + << "WHERE known_outcome = 0 AND p.fspmethod_id = " << m_method_id << " " + << "AND p.variant_id = " << variant.id; + + if (!db->query(ss.str().c_str())) return false; + ss.str(""); + num_fspgroup_entries = db->affected_rows(); + } else { + ss << "INSERT INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id) VALUES "; + std::string insert_sql(ss.str()); + ss.str(""); + + for (uint64_t i = 0; i < samplerows; ++i) { + uint64_t pos = my_rand(pop.get_size() - 1); + Pilot p = pop.remove(pos); + ss << "(" << variant.id << "," << p.instr2 + << "," << p.data_address << "," << m_method_id + << "," << p.id << ")"; + db->insert_multiple(insert_sql.c_str(), ss.str().c_str()); + ss.str(""); + } + db->insert_multiple(); + num_fspgroup_entries = samplerows; + + // the known_outcome=1 pilot has been determined with the "basic" method + known_pilot_method_id = db->get_fspmethod_id("basic"); } - db->insert_multiple(); - unsigned num_fsppilot_entries = samplerows; - - // single entry for known outcome (write access) - ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, injection_instr_absolute, data_address, data_width, fspmethod_id) " - "SELECT 1, variant_id, instr2, instr2, instr2_absolute, " - " data_address, width, " << m_method_id << " " - "FROM trace " - "WHERE variant_id = " << variant.id << " AND accesstype = 'W' " - "ORDER BY instr2 ASC " - "LIMIT 1"; - if (!db->query(ss.str().c_str())) return false; - ss.str(""); - num_fsppilot_entries += db->affected_rows(); - assert(num_fsppilot_entries == (samplerows + 1)); - - LOG << "created " << num_fsppilot_entries << " fsppilot entries" << std::endl; - - // fspgroup entries for sampled trace entries - ss << "INSERT INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id) " - << "SELECT p.variant_id, p.instr2, p.data_address, p.fspmethod_id, p.id " - << "FROM fsppilot p " - << "WHERE known_outcome = 0 AND p.fspmethod_id = " << m_method_id << " " - << "AND p.variant_id = " << variant.id; - - if (!db->query(ss.str().c_str())) return false; - ss.str(""); - unsigned num_fspgroup_entries = db->affected_rows(); #if 0 // do it like the basic pruner: // fspgroup entries for known (W) trace entries @@ -162,9 +230,9 @@ bool FESamplingPruner::sampling_prune(const fail::Database::Variant& variant) // for each W); this needs to be accounted for at data analysis time, // though. ss << "INSERT INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id) " - "SELECT variant_id, instr2, data_address, fspmethod_id, id " + "SELECT variant_id, instr2, data_address, " << m_method_id << ", id " "FROM fsppilot " - "WHERE variant_id = " << variant.id << " AND known_outcome = 1 AND fspmethod_id = " << m_method_id; + "WHERE variant_id = " << variant.id << " AND known_outcome = 1 AND fspmethod_id = " << known_pilot_method_id; #endif if (!db->query(ss.str().c_str())) return false; ss.str(""); diff --git a/tools/prune-trace/FESamplingPruner.hpp b/tools/prune-trace/FESamplingPruner.hpp index a9538622..c938be20 100644 --- a/tools/prune-trace/FESamplingPruner.hpp +++ b/tools/prune-trace/FESamplingPruner.hpp @@ -1,6 +1,7 @@ #ifndef __FESAMPLING_PRUNER_H__ #define __FESAMPLING_PRUNER_H__ +#include #include "Pruner.hpp" #include "util/CommandLine.hpp" @@ -15,18 +16,20 @@ /// class FESamplingPruner : public Pruner { fail::CommandLine::option_handle SAMPLESIZE; + fail::CommandLine::option_handle USE_KNOWN_RESULTS; + fail::CommandLine::option_handle NO_WEIGHTING; - unsigned m_samplesize; + uint64_t m_samplesize; + bool m_use_known_results, m_weighting; public: - FESamplingPruner() : m_samplesize(0) { } + FESamplingPruner() : m_samplesize(0), m_use_known_results(false), m_weighting(true) { } virtual std::string method_name() { return "FESampling"; } virtual bool commandline_init(); virtual bool prune_all(); void getAliases(std::deque *aliases) { aliases->push_back("FESamplingPruner"); - aliases->push_back("sampling"); } private: diff --git a/tools/prune-trace/Pruner.cc b/tools/prune-trace/Pruner.cc index 3e8c7e10..a3491b75 100644 --- a/tools/prune-trace/Pruner.cc +++ b/tools/prune-trace/Pruner.cc @@ -13,7 +13,7 @@ bool Pruner::init( const std::vector& variants_exclude, const std::vector& benchmarks, const std::vector& benchmarks_exclude, - bool overwrite) + bool overwrite, bool incremental) { m_variants = db->get_variants( variants, variants_exclude, @@ -26,8 +26,8 @@ bool Pruner::init( << std::endl; // make sure we only prune variants that haven't been pruned previously - // (unless we run with --overwrite) - if (!overwrite) { + // (unless we run with --overwrite or --incremental) + if (!overwrite && !incremental) { for (std::vector::iterator it = m_variants.begin(); it != m_variants.end(); ) { std::stringstream ss; @@ -100,6 +100,7 @@ bool Pruner::create_database() { " data_address int(10) unsigned NOT NULL," " fspmethod_id int(11) NOT NULL," " pilot_id int(11) NOT NULL," + " weight int(11) UNSIGNED," " PRIMARY KEY (variant_id, data_address, instr2, fspmethod_id)," " KEY joinresults (pilot_id,fspmethod_id)) engine=MyISAM"; diff --git a/tools/prune-trace/Pruner.hpp b/tools/prune-trace/Pruner.hpp index a0ec044d..30fd4e63 100644 --- a/tools/prune-trace/Pruner.hpp +++ b/tools/prune-trace/Pruner.hpp @@ -21,7 +21,7 @@ public: const std::vector& variants_exclude, const std::vector& benchmarks, const std::vector& benchmarks_exclude, - bool overwrite); + bool overwrite, bool incremental); /** * Callback function that can be used to add command line options @@ -35,6 +35,14 @@ public: virtual bool clear_database(); virtual bool prune_all() = 0; + + /** + * Tell the pruner to work incrementally. For example, a sampling pruner + * could add more pilots to already existing ones (which already may be + * associated with fault-injection results). Returns false if the pruner + * is incapable of working in the desired mode. + */ + virtual bool set_incremental(bool incremental) { return !incremental; } }; #endif diff --git a/tools/prune-trace/SamplingPruner.cc b/tools/prune-trace/SamplingPruner.cc new file mode 100644 index 00000000..ce362fcc --- /dev/null +++ b/tools/prune-trace/SamplingPruner.cc @@ -0,0 +1,290 @@ +#include +#include +#include +#include +#include "SamplingPruner.hpp" +#include "util/Logger.hpp" +#include "util/CommandLine.hpp" +#include "util/SumTree.hpp" + +static fail::Logger LOG("SamplingPruner"); +using std::endl; + +struct WeightedPilot { + uint64_t duration; + + uint32_t id; + uint32_t instr2; + uint32_t instr2_absolute; + uint32_t data_address; + uint32_t weight; + + typedef uint64_t size_type; + size_type size() const { return duration; } +}; + +bool SamplingPruner::commandline_init() +{ + fail::CommandLine &cmd = fail::CommandLine::Inst(); + SAMPLESIZE = cmd.addOption("", "samplesize", Arg::Required, + "--samplesize N \tNumber of samples to take (per variant)"); + USE_KNOWN_RESULTS = cmd.addOption("", "use-known-results", Arg::None, + "--use-known-results \tReuse known results from a campaign with the 'basic' pruner "); + NO_WEIGHTING = cmd.addOption("", "no-weighting", Arg::None, + "--no-weighting \tDisable weighted sampling (weight = 1 for all ECs) " + "(don't do this unless you know what you're doing)"); + return true; +} + +bool SamplingPruner::prune_all() +{ + fail::CommandLine &cmd = fail::CommandLine::Inst(); + if (!cmd[SAMPLESIZE]) { + LOG << "parameter --samplesize required, aborting" << endl; + return false; + } + m_samplesize = strtoul(cmd[SAMPLESIZE].first()->arg, 0, 10); + + if (cmd[USE_KNOWN_RESULTS]) { + m_use_known_results = true; + } + + // for each variant: + for (std::vector::const_iterator it = m_variants.begin(); + it != m_variants.end(); ++it) { + if (!sampling_prune(*it)) { + return false; + } + } + + return true; +} + +// TODO: replace with a less syscall-intensive RNG +// TODO: deduplicate (copied from FESamplingPruner), put in a central place +static std::ifstream dev_urandom("/dev/urandom", std::ifstream::binary); +static uint64_t my_rand(uint64_t limit) +{ + // find smallest bitpos that satisfies (1 << bitpos) > limit + int bitpos = 0; + while (limit >> bitpos) { + bitpos++; + } + + uint64_t retval; + + do { + dev_urandom.read((char *) &retval, sizeof(retval)); + retval &= (1ULL << bitpos) - 1; + } while (retval > limit); + + return retval; +} + +bool SamplingPruner::sampling_prune(const fail::Database::Variant& variant) +{ + typedef fail::SumTree sumtree_type; + sumtree_type pop; // sample population + std::stringstream ss; + MYSQL_RES *res; + MYSQL_ROW row; + + uint64_t pilotcount = 0; + + if (!m_use_known_results) { + LOG << "loading trace entries " + << (m_incremental ? "and existing pilots " : "") + << "for " << variant.variant << "/" << variant.benchmark << " ..." << endl; + + if (!m_incremental) { + // load trace entries + ss << "SELECT instr2, instr2_absolute, data_address, time2-time1+1 AS duration" + " FROM trace" + " WHERE variant_id = " << variant.id << + " AND accesstype = 'R'"; + } else { + // load trace entries and existing pilots + ss << "SELECT t.instr2, t.instr2_absolute, t.data_address, t.time2-t.time1+1 AS duration," + " IFNULL(g.pilot_id, 0), IFNULL(g.weight, 0)" + " FROM trace t" + " LEFT JOIN fspgroup g" + " ON t.variant_id = g.variant_id AND t.data_address = g.data_address AND t.instr2 = g.instr2" + " AND g.fspmethod_id = " << m_method_id << + " WHERE t.variant_id = " << variant.id << + " AND t.accesstype = 'R'"; + } + res = db->query_stream(ss.str().c_str()); + ss.str(""); + if (!res) return false; + while ((row = mysql_fetch_row(res))) { + WeightedPilot p; + p.instr2 = strtoul(row[0], 0, 10); + p.instr2_absolute = strtoul(row[1], 0, 10); + p.data_address = strtoul(row[2], 0, 10); + p.duration = m_weighting ? strtoull(row[3], 0, 10) : 1; + p.id = m_incremental ? strtoul(row[4], 0, 10) : 0; + p.weight = m_incremental ? strtoul(row[5], 0, 10) : 0; + pop.add(p); + ++pilotcount; + } + mysql_free_result(res); + } else { + LOG << "loading pilots for " << variant.variant << "/" << variant.benchmark << " ..." << endl; + + if (!m_incremental) { + // load fsppilot entries + ss << "SELECT p.id, p.instr2, p.data_address, t.time2 - t.time1 + 1 AS duration" + " FROM fsppilot p" + " JOIN trace t" + " ON t.variant_id = p.variant_id AND t.data_address = p.data_address AND t.instr2 = p.instr2" + " WHERE p.fspmethod_id = " << db->get_fspmethod_id("basic") << + " AND p.variant_id = " << variant.id << + " AND p.known_outcome = 0"; + } else { + // load fsppilot entries and existing sampling pilots + ss << "SELECT p.id, p.instr2, p.data_address, t.time2 - t.time1 + 1 AS duration, IFNULL(g.weight, 0)" + " FROM fsppilot p" + " JOIN trace t" + " ON t.variant_id = p.variant_id AND t.data_address = p.data_address AND t.instr2 = p.instr2" + " LEFT JOIN fspgroup g" + " ON t.variant_id = g.variant_id AND t.data_address = g.data_address AND t.instr2 = g.instr2" + " AND g.fspmethod_id = " << m_method_id << + " WHERE p.fspmethod_id = " << db->get_fspmethod_id("basic") << + " AND p.variant_id = " << variant.id << + " AND p.known_outcome = 0"; + } + res = db->query_stream(ss.str().c_str()); + ss.str(""); + if (!res) return false; + while ((row = mysql_fetch_row(res))) { + WeightedPilot p; + p.id = strtoul(row[0], 0, 10); + p.instr2 = strtoul(row[1], 0, 10); + p.data_address = strtoul(row[2], 0, 10); + p.duration = m_weighting ? strtoull(row[3], 0, 10) : 1; + p.weight = m_incremental ? strtoull(row[4], 0, 10) : 0; + pop.add(p); + ++pilotcount; + } + mysql_free_result(res); + } + + LOG << "loaded " << pilotcount << " entries, sampling " + << m_samplesize << " fault-space coordinates ..." << endl; + + ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, " + << "injection_instr_absolute, data_address, data_width, fspmethod_id) VALUES "; + std::string insert_sql(ss.str()); + ss.str(""); + + uint64_t popsize = pop.get_size(); // stays constant + uint64_t num_fsppilot_entries = 0; + for (uint64_t i = 0; i < m_samplesize; ++i) { + uint64_t pos = my_rand(popsize - 1); + WeightedPilot& p = pop.get(pos); + p.weight++; + // first time we sample this pilot? + if (!m_use_known_results && p.weight == 1) { + // no need to special-case existing pilots (incremental mode), as + // their initial weight is supposed to be at least 1 + ss << "(0," << variant.id << "," << p.instr2 << "," << p.instr2 + << "," << p.instr2_absolute << "," << p.data_address + << ",1," << m_method_id << ")"; + db->insert_multiple(insert_sql.c_str(), ss.str().c_str()); + ss.str(""); + ++num_fsppilot_entries; + } + } + + if (!m_use_known_results) { + db->insert_multiple(); + LOG << "created " << num_fsppilot_entries << " fsppilot entries" << std::endl; + } + + // fspgroup entries for sampled trace entries + if (!m_use_known_results) { + if (!m_incremental) { + ss << "INSERT"; + } else { + // this spares us to delete existing pilots before + ss << "REPLACE"; + } + ss << " INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id, weight) " + << "SELECT p.variant_id, p.instr2, p.data_address, " << m_method_id << ", p.id, 1 " + << "FROM fsppilot p " + << "WHERE known_outcome = 0 AND p.fspmethod_id = " << m_method_id << " " + << "AND p.variant_id = " << variant.id; + + if (!db->query(ss.str().c_str())) return false; + ss.str(""); + uint64_t num_fspgroup_entries; + if (!m_incremental) { + num_fspgroup_entries = db->affected_rows(); + } else { + // with REPLACE INTO, affected_rows does not yield the number of + // new rows; take num_fsppilot_entries instead + num_fspgroup_entries = num_fsppilot_entries; + } + LOG << "created " << num_fspgroup_entries << " fspgroup entries" << std::endl; + + // FIXME is this faster than manually INSERTing all fspgroup entries? + num_fspgroup_entries = 0; + LOG << "updating fspgroup entries with weight > 1 ..." << std::endl; + for (sumtree_type::iterator it = pop.begin(); it != pop.end(); ++it) { + if (it->weight <= 1) { + continue; + } + ++num_fspgroup_entries; + ss << "UPDATE fspgroup SET weight = " << it->weight << + " WHERE variant_id = " << variant.id << + " AND instr2 = " << it->instr2 << + " AND data_address = " << it->data_address << + " AND fspmethod_id = " << m_method_id; + // pilot_id is known but should be identical + if (!db->query(ss.str().c_str())) return false; + if (db->affected_rows() != 1) { + LOG << "something is wrong, query affected unexpected (" + << db->affected_rows() + << " != 1) number of rows: " + << ss.str() << std::endl; + } + ss.str(""); + } + + if (!m_incremental) { + LOG << "updated " << num_fspgroup_entries << " fspgroup entries" << std::endl; + } else { + // we don't know how many rows we really updated + LOG << "updated fspgroup entries" << std::endl; + } + } else { + uint64_t num_fspgroup_entries = 0; + + LOG << "creating fspgroup entries ..." << std::endl; + + if (!m_incremental) { + ss << "INSERT"; + } else { + // this spares us to delete existing pilots before + ss << "REPLACE"; + } + ss << " INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id, weight) VALUES "; + insert_sql = ss.str(); + ss.str(""); + + for (sumtree_type::iterator it = pop.begin(); it != pop.end(); ++it) { + if (it->weight == 0) { + continue; + } + ++num_fspgroup_entries; + ss << "(" << variant.id << "," << it->instr2 << "," << it->data_address + << "," << m_method_id << "," << it->id << "," << it->weight << ")"; + db->insert_multiple(insert_sql.c_str(), ss.str().c_str()); + ss.str(""); + } + db->insert_multiple(); + LOG << "created " << num_fspgroup_entries << " fspgroup entries" << std::endl; + } + + return true; +} diff --git a/tools/prune-trace/SamplingPruner.hpp b/tools/prune-trace/SamplingPruner.hpp new file mode 100644 index 00000000..9fe6a2ee --- /dev/null +++ b/tools/prune-trace/SamplingPruner.hpp @@ -0,0 +1,39 @@ +#ifndef __SAMPLING_PRUNER_H__ +#define __SAMPLING_PRUNER_H__ + +#include +#include "Pruner.hpp" +#include "util/CommandLine.hpp" + +/// +/// SamplingPruner: implements sampling with equivalence-class reuse +/// +/// Unlike the FESamplingPruner, the SamplingPruner implements uniform +/// fault-space sampling that counts multiple hits of an equivalence class. +/// +class SamplingPruner : public Pruner { + fail::CommandLine::option_handle SAMPLESIZE; + fail::CommandLine::option_handle USE_KNOWN_RESULTS; + fail::CommandLine::option_handle NO_WEIGHTING; + + uint64_t m_samplesize; + bool m_use_known_results, m_weighting, m_incremental; + +public: + SamplingPruner() : m_samplesize(0), m_use_known_results(false), m_weighting(true), m_incremental(false) { } + virtual std::string method_name() { return "sampling"; } + virtual bool commandline_init(); + virtual bool prune_all(); + + void getAliases(std::deque *aliases) { + aliases->push_back("SamplingPruner"); + aliases->push_back("sampling"); + } + + virtual bool set_incremental(bool incremental) { m_incremental = incremental; return true; } + +private: + bool sampling_prune(const fail::Database::Variant& variant); +}; + +#endif diff --git a/tools/prune-trace/main.cc b/tools/prune-trace/main.cc index 295d5c9f..0a391889 100644 --- a/tools/prune-trace/main.cc +++ b/tools/prune-trace/main.cc @@ -14,6 +14,7 @@ using std::endl; #include "Pruner.hpp" #include "BasicPruner.hpp" #include "FESamplingPruner.hpp" +#include "SamplingPruner.hpp" int main(int argc, char *argv[]) { std::string username, hostname, database; @@ -26,6 +27,8 @@ int main(int argc, char *argv[]) { registry.add(&basicprunerleft); FESamplingPruner fesamplingpruner; registry.add(&fesamplingpruner); + SamplingPruner samplingpruner; + registry.add(&samplingpruner); std::string pruners = registry.getPrimeAliasesCSV(); @@ -62,12 +65,20 @@ int main(int argc, char *argv[]) { CommandLine::option_handle OVERWRITE = cmd.addOption("", "overwrite", Arg::None, "--overwrite \tOverwrite already existing pruning data (the default is to skip variants with existing entries)"); + CommandLine::option_handle INCREMENTAL = + cmd.addOption("", "incremental", Arg::None, + "--incremental \tTell the pruner to work incrementally (if supported)"); if (!cmd.parse()) { std::cerr << "Error parsing arguments." << std::endl; exit(-1); } + if (cmd[OVERWRITE] && cmd[INCREMENTAL]) { + std::cerr << "--overwrite and --incremental cannot be used together." << std::endl; + exit(-1); + } + Pruner *pruner; std::string pruner_name = "BasicPruner"; if (cmd[PRUNER]) { @@ -107,6 +118,11 @@ int main(int argc, char *argv[]) { Database *db = Database::cmdline_connect(); pruner->set_db(db); + if (cmd[INCREMENTAL] && !pruner->set_incremental(true)) { + std::cerr << "Pruner is incapable of running incrementally" << std::endl; + exit(-1); + } + std::vector variants, benchmarks, variants_exclude, benchmarks_exclude; if (cmd[VARIANT]) { for (option::Option *o = cmd[VARIANT]; o; o = o->next()) { @@ -147,7 +163,8 @@ int main(int argc, char *argv[]) { exit(-1); } - if (!pruner->init(variants, variants_exclude, benchmarks, benchmarks_exclude, cmd[OVERWRITE])) { + if (!pruner->init(variants, variants_exclude, benchmarks, benchmarks_exclude, + cmd[OVERWRITE], cmd[INCREMENTAL])) { LOG << "pruner->init() failed" << endl; exit(-1); } @@ -155,7 +172,7 @@ int main(int argc, char *argv[]) { //////////////////////////////////////////////////////////////// // Do the actual pruning //////////////////////////////////////////////////////////////// - if (!cmd[NO_DELETE] && cmd[OVERWRITE] && !pruner->clear_database()) { + if (!cmd[NO_DELETE] && cmd[OVERWRITE] && !cmd[INCREMENTAL] && !pruner->clear_database()) { LOG << "clear_database() failed" << endl; exit(-1); }