1#ifndef BTLLIB_MI_BLOOM_FILTER_HPP
2#define BTLLIB_MI_BLOOM_FILTER_HPP
4#include "btllib/nthash.hpp"
5#include "btllib/status.hpp"
7#include "sdsl/bit_vector_il.hpp"
8#include "sdsl/rank_support.hpp"
30 static const T MASK = 1 << (
sizeof(T) * 8 - 1);
31 static const T ANTI_MASK = (T)~MASK;
33 static const T STRAND = 1 << (
sizeof(T) * 8 - 2);
34 static const T ANTI_STRAND = (T)~STRAND;
36 static const T ID_MASK = ANTI_STRAND & ANTI_MASK;
38 static const unsigned BLOCKSIZE = 512;
41 static inline double calc_prob_single_frame(
double occupancy,
44 unsigned allowed_misses)
46 double prob_total = 0.0;
47 for (
unsigned i = hash_num - allowed_misses; i <= hash_num; i++) {
48 double prob = n_choose_k(hash_num, i);
49 prob *= pow(occupancy, i);
50 prob *= pow(1.0 - occupancy, hash_num - i);
51 prob *= (1.0 - pow(1.0 - freq, i));
/// Probability contribution of a single position.
///
/// @param occupancy value multiplied with `freq` (callers in this file pass
///                  popcount / size).
/// @param freq      relative frequency of the ID.
/// @return `occupancy * freq`.
static inline double calc_prob_single(double occupancy, double freq)
{
  return occupancy * freq;
}
65 static size_t calc_optimal_size(
size_t entries,
69 auto non_64_approx_val =
70 size_t(-
double(entries) *
double(hash_num) / log(1.0 - occupancy));
72 return non_64_approx_val + (magic - non_64_approx_val % magic);
80 static unsigned insert(sdsl::bit_vector& bv,
81 const uint64_t* hash_values,
84 unsigned colli_count = 0;
85 for (
unsigned i = 0; i < hash_num; ++i) {
86 const int magic = 0x3f;
87 uint64_t pos = hash_values[i] % bv.size();
88 uint64_t* data_index = bv.data() + (pos >> 6);
89 uint64_t bit_mask_value = (uint64_t)1 << (pos & magic);
91 __sync_fetch_and_or(data_index, bit_mask_value) >> (pos & magic) & 1;
117 sdsl::bit_vector& bv,
118 const std::vector<std::string>& seeds = std::vector<std::string>(0))
120 , m_hash_num(hash_num)
121 , m_kmer_size(kmer_size)
123 , m_prob_saturated(0)
125 m_bv = sdsl::bit_vector_il<BLOCKSIZE>(bv);
126 bv = sdsl::bit_vector();
127 if (!seeds.empty()) {
128 m_ss_val = parse_seeds(m_sseeds);
129 assert(m_sseeds[0].size() == kmer_size);
130 for (
auto itr = m_sseeds.begin(); itr != m_sseeds.end(); ++itr) {
132 assert(m_kmer_size == itr->size());
135 m_rank_support = sdsl::rank_support_il<1>(&m_bv);
136 m_d_size = get_pop();
137 m_data =
new T[m_d_size]();
141 : m_prob_saturated(pow(
double(get_pop_saturated()) /
double(get_pop()),
144#pragma omp parallel for default(none) shared(filter_file_path)
145 for (
unsigned i = 0; i < 2; ++i) {
147 FILE* file = fopen(filter_file_path.c_str(),
"rbe");
149 "MIBloomFilter: File " + filter_file_path +
150 " could not be read.");
153 check_error(fread(&header,
sizeof(
struct FileHeader), 1, file) != 1,
154 "MIBloomFilter: Failed to load header.");
155 log_info(
"MIBloomFilter: Loading header...");
157 const int magic_nine = 9;
158 char magic[magic_nine];
159 const int magic_eight = 8;
160 memcpy(magic, header.magic, magic_eight);
161 magic[magic_eight] =
'\0';
163 log_info(
"MIBloomFilter: Loaded header\nmagic: " + std::string(magic) +
164 "\nhlen: " + std::to_string(header.hlen) +
165 "\nsize: " + std::to_string(header.size) +
166 "\nnhash: " + std::to_string(header.nhash) +
167 "\nkmer: " + std::to_string(header.kmer));
169 m_hash_num = header.nhash;
170 m_kmer_size = header.kmer;
171 m_d_size = header.size;
172 m_data =
new T[m_d_size]();
174 if (header.hlen >
sizeof(
struct FileHeader)) {
176 for (
unsigned i = 0; i < header.nhash; ++i) {
177 char temp[header.kmer];
179 check_error(fread(temp, header.kmer, 1, file) != 1,
180 "MIBloomFilter: Failed to load spaced seed string.");
181 log_info(
"MIBloomFilter: Spaced seed " + std::to_string(i) +
": " +
182 std::string(temp, header.kmer));
183 m_sseeds.push_back(std::string(temp, header.kmer));
186 m_ss_val = parse_seeds(m_sseeds);
187 assert(m_sseeds[0].size() == m_kmer_size);
188 for (
auto itr = m_sseeds.begin(); itr != m_sseeds.end(); ++itr) {
190 assert(m_kmer_size == itr->size());
195 header.hlen != (
sizeof(FileHeader) + m_kmer_size * m_sseeds.size()),
196 "MIBloomFilter: header length: " + std::to_string(header.hlen) +
197 " does not match expected length: " +
198 std::to_string(
sizeof(FileHeader) + m_kmer_size * m_sseeds.size()) +
199 " (likely version mismatch).");
202 "MIBloomFilter: Bloom filter type does not matc.");
204 check_error(header.version != MI_BLOOM_FILTER_VERSION,
205 "MIBloomFilter: Bloom filter version does not match: " +
206 std::to_string(header.version) +
" expected " +
207 std::to_string(MI_BLOOM_FILTER_VERSION) +
".");
209 log_info(
"MIBloomFilter: Loading data vector");
211 long int l_cur_pos = ftell(file);
213 size_t file_size = ftell(file) - header.hlen;
214 fseek(file, l_cur_pos, 0);
217 "MIBloomFilter: " + filter_file_path +
218 " does not match size given by its header. Size: " +
219 std::to_string(file_size) +
" vs " +
220 std::to_string(m_d_size *
sizeof(T)) +
" bytes.");
222 size_t count_read = fread(m_data, file_size, 1, file);
225 "MIBloomFilter: File " + filter_file_path +
226 " could not be read.");
230 std::string bv_filename = filter_file_path +
".sdsl";
231 log_info(
"MIBloomFilter: Loading sdsl interleaved bit vector from: " +
233 load_from_file(m_bv, bv_filename);
234 m_rank_support = sdsl::rank_support_il<1>(&m_bv);
238 log_info(
"MIBloomFilter: Bit vector size: " + std::to_string(m_bv.size()) +
239 "\nPopcount: " + std::to_string(get_pop()));
247 void store(std::string
const& filter_file_path)
const
250#pragma omp parallel for default(none) shared(filter_file_path)
251 for (
unsigned i = 0; i < 2; ++i) {
253 std::ofstream my_file(filter_file_path.c_str(),
254 std::ios::out | std::ios::binary);
257 write_header(my_file);
267 my_file.write(
reinterpret_cast<char*
>(m_data), m_d_size *
sizeof(T));
272 FILE* file = fopen(filter_file_path.c_str(),
"rbe");
274 "MIBloomFilter: " + filter_file_path +
275 " could not be read.");
277 std::string bv_filename = filter_file_path +
".sdsl";
282 store_to_file(m_bv, bv_filename);
306 bool insert(
const uint64_t* hashes,
const bool* strand, T val,
unsigned max)
309 std::vector<unsigned> hash_order;
310 bool saturated =
true;
312 uint64_t rand_value = val;
313 bool strand_dir =
true;
319 for (
unsigned i = 0; i < m_hash_num; ++i) {
321 uint64_t pos = m_rank_support(hashes[i] % m_bv.size());
322 T value = strand_dir ^ strand[i] ? val | STRAND : val;
324 T old_val = m_data[pos];
326 if (old_val > MASK) {
327 old_val = old_val & ANTI_MASK;
332 if (old_val == value) {
335 hash_order.push_back(i);
341 rand_value ^= hashes[i];
343 std::minstd_rand g(rand_value);
344 std::shuffle(hash_order.begin(), hash_order.end(), g);
347 for (
const auto& o : hash_order) {
348 uint64_t pos = m_rank_support(hashes[o] % m_bv.size());
349 T value = strand_dir ^ strand[o] ? val | STRAND : val;
351 T old_val = set_val(&m_data[pos], value);
353 if (old_val > MASK) {
354 old_val = old_val & ANTI_MASK;
384 bool insert(
const uint64_t* hashes, T value,
unsigned max)
387 std::vector<unsigned> hash_order;
389 uint64_t rand_value = value;
391 bool saturated =
true;
394 for (
unsigned i = 0; i < m_hash_num; ++i) {
396 uint64_t pos = m_rank_support(hashes[i] % m_bv.size());
398 T old_val = m_data[pos];
400 if (old_val > MASK) {
401 old_val = old_val & ANTI_MASK;
406 if (old_val == value) {
409 hash_order.push_back(i);
416 rand_value ^= hashes[i];
418 std::minstd_rand g(rand_value);
419 std::shuffle(hash_order.begin(), hash_order.end(), g);
422 for (
const auto& o : hash_order) {
423 uint64_t pos = m_rank_support(hashes[o] % m_bv.size());
425 T old_val = set_val(&m_data[pos], value);
427 if (old_val > MASK) {
428 old_val = old_val & ANTI_MASK;
454 void saturate(
const uint64_t* hashes)
456 for (
unsigned i = 0; i < m_hash_num; ++i) {
457 uint64_t pos = m_rank_support(hashes[i] % m_bv.size());
458 __sync_or_and_fetch(&m_data[pos], MASK);
462 inline std::vector<T> at(
const uint64_t* hashes,
464 unsigned max_miss = 0)
466 std::vector<T> results(m_hash_num);
468 for (
unsigned i = 0; i < m_hash_num; ++i) {
469 uint64_t pos = hashes[i] % m_bv.size();
470 if (m_bv[pos] == 0) {
473 if (misses > max_miss) {
474 return std::vector<T>();
477 uint64_t rank_pos = m_rank_support(pos);
478 T temp_result = m_data[rank_pos];
479 if (temp_result > MASK) {
480 results[i] = m_data[rank_pos] & ANTI_MASK;
482 results[i] = m_data[rank_pos];
494 unsigned at_rank(
const uint64_t* hashes,
495 std::vector<uint64_t>& rank_pos,
496 std::vector<bool>& hits,
497 unsigned max_miss)
const
500 for (
unsigned i = 0; i < m_hash_num; ++i) {
501 uint64_t pos = hashes[i] % m_bv.size();
502 if (
bool(m_bv[pos])) {
503 rank_pos[i] = m_rank_support(pos);
506 if (++misses > max_miss) {
520 bool at_rank(
const uint64_t* hashes, std::vector<uint64_t>& rank_pos)
const
522 for (
unsigned i = 0; i < m_hash_num; ++i) {
523 uint64_t pos = hashes[i] % m_bv.size();
524 if (
bool(m_bv[pos])) {
525 rank_pos[i] = m_rank_support(pos);
533 std::vector<uint64_t> get_rank_pos(
const uint64_t* hashes)
const
535 std::vector<uint64_t> rank_pos(m_hash_num);
536 for (
unsigned i = 0; i < m_hash_num; ++i) {
537 uint64_t pos = hashes[i] % m_bv.size();
538 rank_pos[i] = m_rank_support(pos);
543 uint64_t get_rank_pos(
const uint64_t hash)
const
545 return m_rank_support(hash % m_bv.size());
548 const std::vector<std::vector<unsigned>>& get_seed_values()
const
553 unsigned get_kmer_size()
const {
return m_kmer_size; }
555 unsigned get_hash_num()
const {
return m_hash_num; }
561 size_t get_id_counts(std::vector<size_t>& counts)
const
563 size_t saturated_counts = 0;
564 for (
size_t i = 0; i < m_d_size; ++i) {
565 if (m_data[i] > MASK) {
566 ++counts[m_data[i] & ANTI_MASK];
572 return saturated_counts;
579 size_t get_id_counts_strand(std::vector<size_t>& counts)
const
581 size_t saturated_counts = 0;
582 for (
size_t i = 0; i < m_d_size; ++i) {
583 if (m_data[i] > MASK) {
584 ++counts[m_data[i] & ID_MASK];
587 ++counts[m_data[i] & ANTI_STRAND];
590 return saturated_counts;
593 size_t get_pop()
const
595 size_t index = m_bv.size() - 1;
596 while (m_bv[index] == 0) {
599 return m_rank_support(index) + 1;
606 size_t get_pop_non_zero()
const
609 for (
size_t i = 0; i < m_d_size; ++i) {
610 if (m_data[i] != 0) {
623 T check_values(T max_val)
const
625 for (
size_t i = 0; i < m_d_size; ++i) {
626 if ((m_data[i] & ANTI_MASK) > max_val) {
633 size_t get_pop_saturated()
const
636 for (
size_t i = 0; i < m_d_size; ++i) {
637 if (m_data[i] > MASK) {
644 size_t size()
const {
return m_bv.size(); }
647 void set_data(uint64_t pos, T
id)
651 old_value = m_data[pos];
652 if (old_value > MASK) {
655 }
while (!__sync_bool_compare_and_swap(&m_data[pos], old_value,
id));
659 void saturate_data(uint64_t pos)
666 void set_data_if_empty(uint64_t pos, T
id) { set_val(&m_data[pos],
id); }
668 std::vector<T> get_data(
const std::vector<uint64_t>& rank_pos)
const
670 std::vector<T> results(rank_pos.size());
671 for (
unsigned i = 0; i < m_hash_num; ++i) {
672 results[i] = m_data[rank_pos[i]];
677 T get_data(uint64_t rank)
const {
return m_data[rank]; }
686 double calc_frame_probs(std::vector<double>& frame_probs,
687 unsigned allowed_miss)
689 double occupancy = double(get_pop()) / double(size());
690 std::vector<size_t> count_table =
691 std::vector<size_t>(frame_probs.size(), 0);
692 double sat_prop = double(get_id_counts(count_table));
694 for (
size_t i = 1; i < count_table.size(); ++i) {
695 sum += count_table[i];
697 sat_prop /= double(sum);
698 for (
size_t i = 1; i < count_table.size(); ++i) {
700 calc_prob_single_frame(occupancy,
702 double(count_table[i]) /
double(sum),
715 double calc_frame_probs_strand(std::vector<double>& frame_probs,
716 unsigned allowed_miss)
718 double occupancy = double(get_pop()) / double(size());
719 std::vector<size_t> count_table =
720 std::vector<size_t>(frame_probs.size(), 0);
721 double sat_prop = double(get_id_counts_strand(count_table));
723 for (
const auto& c : count_table) {
726 sat_prop /= double(sum);
727#pragma omp parallel for default(none) shared(count_table)
728 for (
size_t i = 1; i < count_table.size(); ++i) {
730 calc_prob_single_frame(occupancy,
732 double(count_table[i]) /
double(sum),
741 ~MIBloomFilter() {
delete[] m_data; }
/// Strict-weak-ordering comparator on the second element of a pair.
///
/// @return true when `a.second` is smaller than `b.second`.
static bool sort_by_sec(const std::pair<int, int>& a,
                        const std::pair<int, int>& b)
{
  return a.second < b.second;
}
755 void write_header(std::ofstream& out)
const
758 const int magic_num = 8;
759 memcpy(header.magic,
"MIBLOOMF", magic_num);
761 header.hlen =
sizeof(
struct FileHeader) + m_kmer_size * m_sseeds.size();
762 header.kmer = m_kmer_size;
763 header.size = m_d_size;
764 header.nhash = m_hash_num;
765 header.version = MI_BLOOM_FILTER_VERSION;
774 out.write(
reinterpret_cast<char*
>(&header),
sizeof(
struct FileHeader));
776 for (
const auto& s : m_sseeds) {
777 out.write(s.c_str(), m_kmer_size);
/// Number of hash functions for a target false-positive rate.
///
/// Computes -log2(fpr), truncated towards zero. Only meaningful for
/// 0 < fpr < 1; the logarithm is not finite or not positive otherwise.
///
/// @param fpr desired false-positive rate.
/// @return truncated value of -log(fpr) / log(2).
inline static unsigned calc_opti_hash_num(double fpr)
{
  return unsigned(-log(fpr) / log(2.0));
}
794 double calc_fpr_num_inserted(
size_t num_entr)
const
796 return pow(1.0 - pow(1.0 - 1.0 /
double(m_bv.size()),
797 double(num_entr) *
double(m_hash_num)),
804 double calc_fpr_hash_num(
int hash_funct_num)
const
806 const double magic = 2.0;
807 return pow(magic, -hash_funct_num);
814 T set_val(T* val, T new_val)
819 if (old_value != 0) {
822 }
while (!__sync_bool_compare_and_swap(val, old_value, new_val));
826 static inline unsigned n_choose_k(
unsigned n,
unsigned k)
838 for (
unsigned i = 2; i <= k; ++i) {
839 result *= (n - i + 1);
848 sdsl::bit_vector_il<BLOCKSIZE> m_bv;
850 sdsl::rank_support_il<1> m_rank_support;
853 unsigned m_kmer_size;
855 using seed_val = std::vector<std::vector<unsigned>>;
856 std::vector<std::string> m_sseeds;
858 double m_prob_saturated;
861 static const uint32_t MI_BLOOM_FILTER_VERSION = 1;
Definition: mi_bloom_filter.hpp:28
void store(std::string const &filter_file_path) const
Definition: mi_bloom_filter.hpp:247
Definition: bloom_filter.hpp:16
void check_error(bool condition, const std::string &msg)
void log_info(const std::string &msg)