diff --git a/src/core/util/BlackholeLogger.hpp b/src/core/util/BlackholeLogger.hpp new file mode 100644 index 00000000..6ef1674f --- /dev/null +++ b/src/core/util/BlackholeLogger.hpp @@ -0,0 +1,25 @@ +#ifndef __BLACKHOLE_LOGGER_HPP__ +#define __BLACKHOLE_LOGGER_HPP__ + +#include + +namespace fail { + +/** + * \class BlackholeLogger + * A /dev/null sink as a drop-in replacement for Logger. Should be completely + * optimized away on non-trivial optimization levels. + */ +class BlackholeLogger { +public: + Logger(const std::string& description = "Fail*", bool show_time = true, + std::ostream& dest = std::cout) { } + void setDescription(const std::string& descr) { } + void showTime(bool choice) { } + template + inline std::ostream& operator <<(const T& v) { } +}; + +} // end-of-namespace: fail + +#endif // __BLACKHOLE_LOGGER_HPP__ diff --git a/src/core/util/CMakeLists.txt b/src/core/util/CMakeLists.txt index 06052098..b8cd483d 100644 --- a/src/core/util/CMakeLists.txt +++ b/src/core/util/CMakeLists.txt @@ -89,3 +89,6 @@ endif (BUILD_LLVM_DISASSEMBLER) add_executable(memorymap-test testing/memorymap-test.cc) target_link_libraries(memorymap-test fail-util) add_test(NAME memorymap-test WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/testing COMMAND memorymap-test) + +add_executable(sumtree-test testing/SumTreeTest.cc) +add_test(NAME sumtree-test WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/testing COMMAND sumtree-test) diff --git a/src/core/util/Database.cc b/src/core/util/Database.cc index 398b0299..a5a64a17 100644 --- a/src/core/util/Database.cc +++ b/src/core/util/Database.cc @@ -141,34 +141,81 @@ my_ulonglong Database::insert_id() return mysql_insert_id(handle); } -std::vector Database::get_variants(const std::string &variant, const std::string &benchmark) { - std::vector result; - +bool Database::create_variants_table() +{ if (!query("CREATE TABLE IF NOT EXISTS variant (" - " id int(11) NOT NULL AUTO_INCREMENT," - " variant varchar(255) NOT NULL," - " benchmark varchar(255) NOT NULL," - " PRIMARY KEY (id)," - "UNIQUE KEY variant (variant,benchmark)) ENGINE=MyISAM")) { + " id int(11) NOT NULL AUTO_INCREMENT," + " variant varchar(255) NOT NULL," + " benchmark varchar(255) NOT NULL," + " PRIMARY KEY (id)," + "UNIQUE KEY variant (variant,benchmark)) ENGINE=MyISAM")) { + return false; + } + return true; +} + +std::vector Database::get_variants(const std::string &variant, const std::string &benchmark) +{ + std::vector variants; + variants.push_back(variant); + std::vector benchmarks; + benchmarks.push_back(benchmark); + std::vector dummy; + + return get_variants(variants, dummy, benchmarks, dummy); +} + +std::vector Database::get_variants( + const std::vector& variants, + const std::vector& variants_exclude, + const std::vector& benchmarks, + const std::vector& benchmarks_exclude) +{ + std::vector result; + std::stringstream ss; + + // make sure variant table exists + if (!create_variants_table()) { return result; } - std::stringstream ss; - // FIXME SQL injection possible - ss << "SELECT id, variant, benchmark FROM variant WHERE variant LIKE '" << variant << "' AND benchmark LIKE '" << benchmark << "'"; - MYSQL_RES *variant_id_res = query(ss.str().c_str(), true); + // FIXME string escaping + ss << "SELECT id, variant, benchmark FROM variant WHERE "; + ss << "("; + for (std::vector::const_iterator it = variants.begin(); + it != variants.end(); ++it) { + ss << "variant LIKE '" << *it << "' OR "; + } + ss << "0) AND ("; + for (std::vector::const_iterator it = benchmarks.begin(); + it != benchmarks.end(); ++it) { + ss << "benchmark LIKE '" << *it << "' OR "; + } + // dummy terminator to avoid special cases in query construction above + ss << "0) AND NOT ("; + for (std::vector::const_iterator it = variants_exclude.begin(); + it != variants_exclude.end(); ++it) { + ss << "variant LIKE '" << *it << "' OR "; + } + for (std::vector::const_iterator it = benchmarks_exclude.begin(); + it != benchmarks_exclude.end(); ++it) { + ss << "benchmark LIKE '" << *it << "' OR "; + } + // dummy terminator to avoid special cases in query construction above + ss << "0)"; + MYSQL_RES *variant_id_res = query(ss.str().c_str(), true); if (!variant_id_res) { return result; - } else if (mysql_num_rows(variant_id_res)) { - for (unsigned int i = 0; i < mysql_num_rows(variant_id_res); ++i) { - MYSQL_ROW row = mysql_fetch_row(variant_id_res); - Variant var; - var.id = atoi(row[0]); - var.variant = std::string(row[1]); - var.benchmark = std::string(row[2]); - result.push_back(var); - } + } + + MYSQL_ROW row; + while ((row = mysql_fetch_row(variant_id_res))) { + Variant var; + var.id = atoi(row[0]); + var.variant = row[1]; + var.benchmark = row[2]; + result.push_back(var); } return result; diff --git a/src/core/util/Database.hpp b/src/core/util/Database.hpp index 626b9eaa..3a1ed817 100644 --- a/src/core/util/Database.hpp +++ b/src/core/util/Database.hpp @@ -54,6 +54,16 @@ namespace fail { */ std::vector get_variants(const std::string &variant, const std::string &benchmark); + /** + * Get all variants that fit one of the variant, one of the benchmark, + * and none of the variant/benchmark exclude patterns (will be queried + * with SQL LIKE). + */ + std::vector get_variants( + const std::vector& variants, + const std::vector& variants_exclude, + const std::vector& benchmarks, + const std::vector& benchmarks_exclude); /** * Get the fault space pruning method id for a specific @@ -113,6 +123,9 @@ namespace fail { */ static void cmdline_setup(); static Database * cmdline_connect(); + + private: + bool create_variants_table(); }; } diff --git a/src/core/util/SumTree.hpp b/src/core/util/SumTree.hpp new file mode 100644 index 00000000..998e3e05 --- /dev/null +++ b/src/core/util/SumTree.hpp @@ -0,0 +1,189 @@ +#ifndef __SUM_TREE_HPP__ +#define __SUM_TREE_HPP__ + +#include +#include +#include + +// The SumTree implements an efficient tree data structure for +// "roulette-wheel" sampling, or "sampling with fault expansion", i.e., +// sampling of trace entries / pilots without replacement and with a +// picking probability proportional to the entries' sizes. +// +// For every sample, the naive approach picks a random number between 0 +// and the sum of all entry sizes minus one. It then iterates over all +// entries and sums their sizes until the sum exceeds the random number. +// The current entry gets picked. The main disadvantage is the linear +// complexity, which gets unpleasant for millions of entries. +// +// The core idea behind the SumTree implementation is to maintain the +// size sum of groups of entries, kept in "buckets". Thereby, a bucket +// can be quickly jumped over. To keep bucket sizes (and thereby linear +// search times) bounded, more bucket hierarchy levels are introduced +// when a defined bucket size limit is reached. +// +// Note that the current implementation is built for a pure growth phase +// (when the tree gets filled with pilots from the database), followed by +// a sampling phase when the tree gets emptied. It does not handle a +// mixed add/remove case very smartly, although it should remain +// functional. + +namespace fail { + +template +class SumTree { + //! Bucket data structure for tree nodes + struct Bucket { + Bucket() : size(0) {} + ~Bucket(); + //! Sum of all children / elements + typename T::size_type size; + //! Sub-buckets, empty for leaf nodes + std::vector children; + //! Contained elements, empty for inner nodes + std::vector elements; + }; + + //! Root node + Bucket *m_root; + //! Tree depth: nodes at level m_depth are leaf nodes, others are inner nodes + unsigned m_depth; +public: + SumTree() : m_root(new Bucket), m_depth(0) {} + ~SumTree() { delete m_root; } + //! Adds a new element to the tree. + void add(const T& element); + //! Retrieves (and removes) element at random number position. + T get(typename T::size_type pos) { return get(pos, m_root, 0); } + //! Yields the sum over all elements in the tree. + typename T::size_type get_size() const { return m_root->size; } +private: + //! Internal, recursive version of add(). + bool add(Bucket **node, const T& element, unsigned depth_remaining); + //! Internal, recursive version of get(). + T get(typename T::size_type pos, Bucket *node, typename T::size_type sum); +}; + +// template implementation + +template +SumTree::Bucket::~Bucket() +{ + for (typename std::vector::const_iterator it = children.begin(); + it != children.end(); ++it) { + delete *it; + } +} + +template +void SumTree::add(const T& element) +{ + if (element.size() == 0) { + // pilots with size == 0 cannot be picked anyways + return; + } + + if (add(&m_root, element, m_depth)) { + // tree wasn't full yet, add succeeded + return; + } + + // tree is full, move everything one level down + ++m_depth; + Bucket *b = new Bucket; + b->children.push_back(m_root); + b->size = m_root->size; + m_root = b; + + // retry + add(&m_root, element, m_depth); +} + +template +bool SumTree::add(Bucket **node, const T& element, unsigned depth_remaining) +{ + // non-leaf node? + if (depth_remaining) { + // no children yet? create one. + if ((*node)->children.size() == 0) { + (*node)->children.push_back(new Bucket); + } + + // adding to newest child worked? + if (add(&(*node)->children.back(), element, depth_remaining - 1)) { + (*node)->size += element.size(); + return true; + } + + // newest child full, may we create another one? + if ((*node)->children.size() < BUCKETSIZE) { + (*node)->children.push_back(new Bucket); + add(&(*node)->children.back(), element, depth_remaining - 1); + (*node)->size += element.size(); + return true; + } + // recursive add ultimately failed, subtree full + return false; + + // leaf node + } else { + if ((*node)->elements.size() < BUCKETSIZE) { + (*node)->elements.push_back(element); + (*node)->size += element.size(); + return true; + } + return false; + } +} + +template +T SumTree::get(typename T::size_type pos, Bucket *node, typename T::size_type sum) +{ + // sanity check + assert(pos >= sum && pos < sum + node->size); + + // will only be entered for inner nodes + for (typename std::vector::iterator it = node->children.begin(); + it != node->children.end(); ) { + sum += (*it)->size; + if (sum <= pos) { + ++it; + continue; + } + + // found containing bucket, recurse + sum -= (*it)->size; + T e = get(pos, *it, sum); + node->size -= e.size(); + // remove empty (or, at least, zero-sized) child? + if ((*it)->size == 0) { + delete *it; + node->children.erase(it); + } + return e; + } + + // will only be entered for leaf nodes + for (typename std::vector::iterator it = node->elements.begin(); + it != node->elements.end(); ) { + sum += it->size(); + if (sum <= pos) { + ++it; + continue; + } + + // found pilot + T e = *it; + node->size -= e.size(); + node->elements.erase(it); + return e; + } + + // this should never happen + assert(0); + return T(); +} + +} // namespace + +#endif diff --git a/src/core/util/testing/SumTreeTest.cc b/src/core/util/testing/SumTreeTest.cc new file mode 100644 index 00000000..1757cd17 --- /dev/null +++ b/src/core/util/testing/SumTreeTest.cc @@ -0,0 +1,34 @@ +#include "util/SumTree.hpp" + +#include +#define LOG std::cerr + +using std::endl; + +struct Pilot { + uint32_t id; + uint32_t instr2; + uint32_t data_address; + uint64_t duration; + + typedef uint64_t size_type; + size_type size() const { return duration; } +}; + +int main() +{ + fail::SumTree tree; + for (int i = 0; i <= 20; ++i) { + Pilot p; + p.duration = i; + tree.add(p); + } + + while (tree.get_size() > 0) { + uint64_t pos = tree.get_size() / 2; + LOG << "MAIN tree.get_size() = " << tree.get_size() + << ", trying to retrieve pos = " << pos << endl; + Pilot p = tree.get(pos); + LOG << "MAIN retrieved pilot with duration " << p.duration << endl; + } +} diff --git a/tools/prune-trace/BasicPruner.cc b/tools/prune-trace/BasicPruner.cc index 9c77b988..5fbc1d9d 100644 --- a/tools/prune-trace/BasicPruner.cc +++ b/tools/prune-trace/BasicPruner.cc @@ -19,22 +19,21 @@ bool BasicPruner::prune_all() { "SELECT 0, variant_id, instr2, " << injection_instr << ", " << injection_instr_absolute << ", " " data_address, width, " << m_method_id << " " "FROM trace " - "WHERE variant_id IN (" << m_variant_id_query << ") AND accesstype = 'R'"; + "WHERE variant_id IN (" << m_variants_sql << ") AND accesstype = 'R'"; if (!db->query(ss.str().c_str())) return false; ss.str(""); int rows = db->affected_rows(); // for each variant: - MYSQL_RES *res = db->query(m_variant_id_query.c_str(), true); - MYSQL_ROW row; - while ((row = mysql_fetch_row(res))) { + for (std::vector::const_iterator it = m_variants.begin(); + it != m_variants.end(); ++it) { // single entry for known outcome (write access) ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, injection_instr_absolute, data_address, data_width, fspmethod_id) " "SELECT 1, variant_id, instr2, " << injection_instr << ", " << injection_instr_absolute << ", " " data_address, width, " << m_method_id << " " "FROM trace " - "WHERE variant_id = " << row[0] << " AND accesstype = 'W' " + "WHERE variant_id = " << it->id << " AND accesstype = 'W' " "ORDER BY instr2 ASC " "LIMIT 1"; if (!db->query(ss.str().c_str())) return false; @@ -50,7 +49,7 @@ bool BasicPruner::prune_all() { << "JOIN trace t ON t.variant_id = p.variant_id AND t.instr2 = p.instr2" << " AND t.data_address = p.data_address " << "WHERE known_outcome = 0 AND p.fspmethod_id = " << m_method_id << " " - << "AND p.variant_id IN (" << m_variant_id_query << ")"; + << "AND p.variant_id IN (" << m_variants_sql << ")"; if (!db->query(ss.str().c_str())) return false; ss.str(""); @@ -61,7 +60,7 @@ bool BasicPruner::prune_all() { "FROM fsppilot p " "JOIN trace t " "ON t.variant_id = p.variant_id AND p.fspmethod_id = " << m_method_id << " AND p.known_outcome = 1 " - "WHERE t.variant_id IN (" << m_variant_id_query << ") AND t.accesstype = 'W'"; + "WHERE t.variant_id IN (" << m_variants_sql << ") AND t.accesstype = 'W'"; if (!db->query(ss.str().c_str())) return false; ss.str(""); rows += db->affected_rows(); diff --git a/tools/prune-trace/CMakeLists.txt b/tools/prune-trace/CMakeLists.txt index a79e3451..4043ffd4 100644 --- a/tools/prune-trace/CMakeLists.txt +++ b/tools/prune-trace/CMakeLists.txt @@ -1,6 +1,7 @@ set(SRCS Pruner.cc BasicPruner.cc + FESamplingPruner.cc ) find_package(MySQL REQUIRED) diff --git a/tools/prune-trace/FESamplingPruner.cc b/tools/prune-trace/FESamplingPruner.cc new file mode 100644 index 00000000..fb4f925b --- /dev/null +++ b/tools/prune-trace/FESamplingPruner.cc @@ -0,0 +1,176 @@ +#include +#include +#include +#include +#include "FESamplingPruner.hpp" +#include "util/Logger.hpp" +#include "util/CommandLine.hpp" +#include "util/SumTree.hpp" + +static fail::Logger LOG("FESamplingPruner"); +using std::endl; + +struct Pilot { + uint64_t duration; + + uint32_t instr2; + uint32_t instr2_absolute; + uint32_t data_address; + + typedef uint64_t size_type; + size_type size() const { return duration; } +}; + +bool FESamplingPruner::commandline_init() +{ + fail::CommandLine &cmd = fail::CommandLine::Inst(); + SAMPLESIZE = cmd.addOption("", "samplesize", Arg::Required, + "--samplesize N \tNumber of samples to take (per variant)"); + return true; +} + +bool FESamplingPruner::prune_all() +{ + fail::CommandLine &cmd = fail::CommandLine::Inst(); + if (!cmd[SAMPLESIZE]) { + LOG << "parameter --samplesize required, aborting" << endl; + return false; + } + m_samplesize = strtoul(cmd[SAMPLESIZE].first()->arg, 0, 10); + + // for each variant: + for (std::vector::const_iterator it = m_variants.begin(); + it != m_variants.end(); ++it) { + if (!sampling_prune(*it)) { + return false; + } + } + + return true; +} + +// TODO: replace with a less syscall-intensive RNG +static std::ifstream dev_urandom("/dev/urandom", std::ifstream::binary); +static uint64_t my_rand(uint64_t limit) +{ + // find smallest bitpos that satisfies (1 << bitpos) > limit + int bitpos = 0; + while (limit >> bitpos) { + bitpos++; + } + + uint64_t retval; + + do { + dev_urandom.read((char *) &retval, sizeof(retval)); + retval &= (1ULL << bitpos) - 1; + } while (retval > limit); + + return retval; +} + +bool FESamplingPruner::sampling_prune(const fail::Database::Variant& variant) +{ + fail::SumTree pop; // sample population + std::stringstream ss; + MYSQL_RES *res; + MYSQL_ROW row; + + LOG << "loading trace entries for " << variant.variant << "/" << variant.benchmark << " ..." << endl; + + unsigned pilotcount = 0; + + // load trace entries + ss << "SELECT instr2, instr2_absolute, data_address, time2-time1+1 AS duration" + << " FROM trace" + << " WHERE variant_id = " << variant.id + << " AND accesstype = 'R'" + << " ORDER BY duration DESC"; // speeds up sampling, but query may be slow + res = db->query_stream(ss.str().c_str()); + ss.str(""); + if (!res) return false; + while ((row = mysql_fetch_row(res))) { + Pilot p; + p.instr2 = strtoul(row[0], 0, 10); + p.instr2_absolute = strtoul(row[1], 0, 10); + p.data_address = strtoul(row[2], 0, 10); + p.duration = strtoull(row[3], 0, 10); + pop.add(p); + ++pilotcount; + } + mysql_free_result(res); + + unsigned samplerows = std::min(pilotcount, m_samplesize); + + LOG << "loaded " << pilotcount << " entries, sampling " + << samplerows << " entries with fault expansion ..." << endl; + + // FIXME: change strategy when trace entries have IDs, insert into fspgroup first + ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, " + << "injection_instr_absolute, data_address, data_width, fspmethod_id) VALUES "; + std::string insert_sql(ss.str()); + ss.str(""); + + for (unsigned i = 0; i < samplerows; ++i) { + uint64_t pos = my_rand(pop.get_size() - 1); + Pilot p = pop.get(pos); + ss << "(0," << variant.id << "," << p.instr2 << "," << p.instr2 + << "," << p.instr2_absolute << "," << p.data_address + << ",1," << m_method_id << ")"; + db->insert_multiple(insert_sql.c_str(), ss.str().c_str()); + ss.str(""); + } + db->insert_multiple(); + unsigned num_fsppilot_entries = samplerows; + + // single entry for known outcome (write access) + ss << "INSERT INTO fsppilot (known_outcome, variant_id, instr2, injection_instr, injection_instr_absolute, data_address, data_width, fspmethod_id) " + "SELECT 1, variant_id, instr2, instr2, instr2_absolute, " + " data_address, width, " << m_method_id << " " + "FROM trace " + "WHERE variant_id = " << variant.id << " AND accesstype = 'W' " + "ORDER BY instr2 ASC " + "LIMIT 1"; + if (!db->query(ss.str().c_str())) return false; + ss.str(""); + num_fsppilot_entries += db->affected_rows(); + assert(num_fsppilot_entries == (samplerows + 1)); + + LOG << "created " << num_fsppilot_entries << " fsppilot entries" << std::endl; + + // fspgroup entries for sampled trace entries + ss << "INSERT INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id) " + << "SELECT p.variant_id, p.instr2, p.data_address, p.fspmethod_id, p.id " + << "FROM fsppilot p " + << "WHERE known_outcome = 0 AND p.fspmethod_id = " << m_method_id << " " + << "AND p.variant_id = " << variant.id; + + if (!db->query(ss.str().c_str())) return false; + ss.str(""); + unsigned num_fspgroup_entries = db->affected_rows(); + +#if 0 // do it like the basic pruner: + // fspgroup entries for known (W) trace entries + ss << "INSERT INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id) " + "SELECT STRAIGHT_JOIN t.variant_id, t.instr2, t.data_address, p.fspmethod_id, p.id " + "FROM fsppilot p " + "JOIN trace t " + "ON t.variant_id = p.variant_id AND p.fspmethod_id = " << m_method_id << " AND p.known_outcome = 1 " + "WHERE t.variant_id = " << variant.id << " AND t.accesstype = 'W'"; +#else + // *one* fspgroup entry for known (W) trace entries (no need to create one + // for each W); this needs to be accounted for at data analysis time, + // though. + ss << "INSERT INTO fspgroup (variant_id, instr2, data_address, fspmethod_id, pilot_id) " + "SELECT variant_id, instr2, data_address, fspmethod_id, id " + "FROM fsppilot " + "WHERE variant_id = " << variant.id << " AND known_outcome = 1 AND fspmethod_id = " << m_method_id; +#endif + if (!db->query(ss.str().c_str())) return false; + ss.str(""); + num_fspgroup_entries += db->affected_rows(); + + LOG << "created " << num_fspgroup_entries << " fspgroup entries" << std::endl; + + return true; +} diff --git a/tools/prune-trace/FESamplingPruner.hpp b/tools/prune-trace/FESamplingPruner.hpp new file mode 100644 index 00000000..6d4dc0cf --- /dev/null +++ b/tools/prune-trace/FESamplingPruner.hpp @@ -0,0 +1,31 @@ +#ifndef __FESAMPLING_PRUNER_H__ +#define __FESAMPLING_PRUNER_H__ + +#include "Pruner.hpp" +#include "util/CommandLine.hpp" + +/// +/// FESamplingPruner: implements sampling with Fault Expansion +/// +/// The FESamplingPruner implements the fault-expansion variance reduction +/// technique (FE-VRT) as described in: Smith, D. Todd and Johnson, Barry W. +/// and Andrianos, Nikos and Profeta, III, Joseph A., "A variance-reduction +/// technique via fault-expansion for fault-coverage estimation" (1997), +/// 366--374. +/// +class FESamplingPruner : public Pruner { + fail::CommandLine::option_handle SAMPLESIZE; + + unsigned m_samplesize; + +public: + FESamplingPruner() : m_samplesize(0) { } + virtual std::string method_name() { return "FESampling"; } + virtual bool commandline_init(); + virtual bool prune_all(); + +private: + bool sampling_prune(const fail::Database::Variant& variant); +}; + +#endif diff --git a/tools/prune-trace/Pruner.cc b/tools/prune-trace/Pruner.cc index a3d339c5..644a9d0b 100644 --- a/tools/prune-trace/Pruner.cc +++ b/tools/prune-trace/Pruner.cc @@ -15,29 +15,24 @@ bool Pruner::init(fail::Database *db, const std::vector& benchmarks_exclude) { this->db = db; - std::stringstream ss; - // FIXME string escaping - ss << "SELECT id FROM variant WHERE "; - for (std::vector::const_iterator it = variants.begin(); - it != variants.end(); ++it) { - ss << "variant LIKE '" << *it << "' AND "; + m_variants = db->get_variants( + variants, variants_exclude, + benchmarks, benchmarks_exclude); + if (m_variants.size() == 0) { + LOG << "no variants found, nothing to do" << std::endl; + return false; } - for (std::vector::const_iterator it = variants_exclude.begin(); - it != variants_exclude.end(); ++it) { - ss << "variant NOT LIKE '" << *it << "' AND "; + + std::stringstream ss; + for (std::vector::const_iterator it = m_variants.begin(); + it != m_variants.end(); ++it) { + if (it != m_variants.begin()) { + ss << ","; + } + ss << it->id; } - for (std::vector::const_iterator it = benchmarks.begin(); - it != benchmarks.end(); ++it) { - ss << "benchmark LIKE '" << *it << "' AND "; - } - for (std::vector::const_iterator it = benchmarks_exclude.begin(); - it != benchmarks_exclude.end(); ++it) { - ss << "benchmark NOT LIKE '" << *it << "' AND "; - } - // dummy terminator to avoid special cases in query construction above - ss << "1"; - m_variant_id_query = ss.str(); + m_variants_sql = ss.str(); if (!(m_method_id = db->get_fspmethod_id(method_name()))) { return false; @@ -79,12 +74,14 @@ bool Pruner::create_database() { bool Pruner::clear_database() { std::stringstream ss; - ss << "DELETE FROM fsppilot WHERE variant_id IN (" << m_variant_id_query << ")"; + ss << "DELETE FROM fsppilot WHERE variant_id IN (" << m_variants_sql + << ") AND fspmethod_id = " << m_method_id; bool ret = (bool) db->query(ss.str().c_str()); LOG << "deleted " << db->affected_rows() << " rows from fsppilot table" << std::endl; ss.str(""); - ss << "DELETE FROM fspgroup WHERE variant_id IN (" << m_variant_id_query << ")"; + ss << "DELETE FROM fspgroup WHERE variant_id IN (" << m_variants_sql + << ") AND fspmethod_id = " << m_method_id; ret = ret && (bool) db->query(ss.str().c_str()); LOG << "deleted " << db->affected_rows() << " rows from fspgroup table" << std::endl; diff --git a/tools/prune-trace/Pruner.hpp b/tools/prune-trace/Pruner.hpp index cbbb2113..e6e1f5dc 100644 --- a/tools/prune-trace/Pruner.hpp +++ b/tools/prune-trace/Pruner.hpp @@ -8,8 +8,9 @@ class Pruner { protected: int m_method_id; - std::string m_variant_id_query; fail::Database *db; + std::vector m_variants; + std::string m_variants_sql; public: bool init(fail::Database *db, @@ -18,6 +19,12 @@ public: const std::vector& benchmarks, const std::vector& benchmarks_exclude); + /** + * Callback function that can be used to add command line options + * to the cmd interface + */ + virtual bool commandline_init() { return true; } + virtual std::string method_name() = 0; virtual bool create_database(); diff --git a/tools/prune-trace/main.cc b/tools/prune-trace/main.cc index 709c22b1..55d4c47a 100644 --- a/tools/prune-trace/main.cc +++ b/tools/prune-trace/main.cc @@ -11,6 +11,7 @@ using std::endl; #include "Pruner.hpp" #include "BasicPruner.hpp" +#include "FESamplingPruner.hpp" int main(int argc, char *argv[]) { std::string username, hostname, database; @@ -59,6 +60,9 @@ int main(int argc, char *argv[]) { } else if (imp == "BasicPrunerLeft" || imp == "basic-left") { LOG << "Using BasicPruner (use left border, instr1)" << endl; pruner = new BasicPruner(true); + } else if (imp == "FESamplingPruner" || imp == "sampling") { + LOG << "Using FESamplingPruner" << endl; + pruner = new FESamplingPruner; } else { LOG << "Unknown pruning method: " << imp << endl; @@ -70,6 +74,14 @@ int main(int argc, char *argv[]) { pruner = new BasicPruner(); } + if (pruner && !(pruner->commandline_init())) { + std::cerr << "Pruner's commandline initialization failed" << std::endl; + exit(-1); + } + // Since the pruner might have added command line options, we need to + // reparse all arguments. + cmd.parse(); + if (cmd[HELP]) { cmd.printUsage(); exit(0);