1#ifndef KAORI_MISMATCH_TRIE_HPP
2#define KAORI_MISMATCH_TRIE_HPP
54int trie_base_shift() {
55 if constexpr(base_ ==
'A') {
57 }
else if constexpr(base_ ==
'C') {
59 }
else if constexpr(base_ ==
'G') {
68 MismatchTrie() =
default;
70 MismatchTrie(SeqLength barcode_length, DuplicateAction duplicates) :
71 my_length(barcode_length),
72 my_duplicates(duplicates),
79 std::vector<BarcodeIndex> my_pointers;
83 auto current = my_pointers[node];
84 if (current == STATUS_UNMATCHED) {
85 if (my_pointers.size() > std::numeric_limits<BarcodeIndex>::max()) {
86 throw std::runtime_error(
"integer overflow for trie nodes");
88 current = my_pointers.size();
89 my_pointers[node] = current;
90 my_pointers.insert(my_pointers.end(), NUM_BASES, STATUS_UNMATCHED);
95 void end(BarcodeIndex node, TrieAddStatus& status) {
96 auto& current = my_pointers[node];
98 if (current == STATUS_UNMATCHED) {
100 }
else if (current == STATUS_AMBIGUOUS) {
101 status.is_duplicate =
true;
103 status.is_duplicate =
true;
104 switch(my_duplicates) {
105 case DuplicateAction::FIRST:
107 case DuplicateAction::LAST:
108 status.duplicate_replaced =
true;
109 current = my_counter;
111 case DuplicateAction::NONE:
112 status.duplicate_cleared =
true;
115 case DuplicateAction::ERROR:
116 throw std::runtime_error(
"duplicate sequences detected (" +
117 std::to_string(current + 1) +
", " +
118 std::to_string(my_counter + 1) +
") when constructing the trie");
124 void process_ambiguous(SeqLength i, BarcodeIndex node,
const char* barcode_seq, TrieAddStatus& status) {
125 node += trie_base_shift<base_>();
127 if (i == my_length) {
131 recursive_add(i, node, barcode_seq, status);
135 void recursive_add(SeqLength i, BarcodeIndex node,
const char* barcode_seq, TrieAddStatus& status) {
139 switch (barcode_seq[i]) {
141 node += trie_base_shift<'A'>();
break;
143 node += trie_base_shift<'C'>();
break;
145 node += trie_base_shift<'G'>();
break;
147 node += trie_base_shift<'T'>();
break;
151 if ((++i) == my_length) {
161 status.has_ambiguous =
true;
163 auto processA = [&]() ->
void { process_ambiguous<'A'>(i, node, barcode_seq, status); };
164 auto processC = [&]() ->
void { process_ambiguous<'C'>(i, node, barcode_seq, status); };
165 auto processG = [&]() ->
void { process_ambiguous<'G'>(i, node, barcode_seq, status); };
166 auto processT = [&]() ->
void { process_ambiguous<'T'>(i, node, barcode_seq, status); };
168 switch(barcode_seq[i]) {
170 processA(); processG();
break;
172 processC(); processT();
break;
174 processC(); processG();
break;
176 processA(); processT();
break;
178 processG(); processT();
break;
180 processA(); processC();
break;
182 processC(); processG(); processT();
break;
184 processA(); processG(); processT();
break;
186 processA(); processC(); processT();
break;
188 processA(); processC(); processG();
break;
190 processA(); processC(); processG(); processT();
break;
192 throw std::runtime_error(
"unknown base '" + std::string(1, barcode_seq[i]) +
"' detected when constructing the trie");
197 TrieAddStatus add(
const char* barcode_seq) {
198 TrieAddStatus status;
199 recursive_add(0, 0, barcode_seq, status);
212 const std::vector<BarcodeIndex>& pointers()
const {
218 template<
class SearchResult_>
219 void replace_best_with_chosen(SearchResult_& best, BarcodeIndex& best_index,
int best_score,
const SearchResult_& chosen, BarcodeIndex chosen_index,
int chosen_score)
const {
221 if (chosen_score < best_score) {
223 }
else if (chosen_score == best_score) {
224 if (chosen_index != best_index) {
225 if (my_duplicates == DuplicateAction::FIRST) {
226 if (chosen_index < best_index) {
227 best_index = chosen_index;
229 }
else if (my_duplicates == DuplicateAction::LAST) {
230 if (chosen_index > best_index) {
231 best_index = chosen_index;
239 }
else if (chosen_index == STATUS_AMBIGUOUS) {
240 if (chosen_score < best_score) {
242 }
else if (chosen_score == best_score) {
252 void scan_final_position_with_mismatch(BarcodeIndex node,
int refshift, BarcodeIndex& current_index,
int current_mismatches,
int& mismatch_cap)
const {
254 for (
int s = 0; s < NUM_BASES; ++s) {
259 auto candidate = my_pointers[node + s];
262 if (candidate != current_index) {
263 if (my_duplicates == DuplicateAction::FIRST) {
264 if (current_index > candidate) {
265 current_index = candidate;
267 }
else if (my_duplicates == DuplicateAction::LAST) {
268 if (current_index < candidate) {
269 current_index = candidate;
277 current_index = candidate;
278 mismatch_cap = current_mismatches;
282 }
else if (candidate == STATUS_AMBIGUOUS) {
291 mismatch_cap = current_mismatches;
300 if (!is_optimal(0, 0, maxed)) {
301 std::vector<BarcodeIndex> replacement;
302 replacement.reserve(my_pointers.size());
303 optimize(0, 0, replacement);
304 my_pointers.swap(replacement);
312 bool is_optimal(SeqLength i, BarcodeIndex node, BarcodeIndex& maxed)
const {
315 for (
int s = 0; s < NUM_BASES; ++s) {
316 auto v = my_pointers[node + s];
326 if (!is_optimal(i, v, maxed)) {
334 void optimize(SeqLength i, BarcodeIndex node, std::vector<BarcodeIndex>& trie)
const {
335 auto it = my_pointers.begin() + node;
337 trie.insert(trie.end(), it, it + NUM_BASES);
341 for (
int s = 0; s < NUM_BASES; ++s) {
342 auto& v = trie[new_node + s];
349 optimize(i, original, trie);
355inline std::pair<BarcodeIndex, int> trie_next_base(
char base,
BarcodeIndex node,
const std::vector<BarcodeIndex>& pointers) {
360 shift = trie_base_shift<'A'>(); current = pointers[node + shift];
break;
362 shift = trie_base_shift<'C'>(); current = pointers[node + shift];
break;
364 shift = trie_base_shift<'G'>(); current = pointers[node + shift];
break;
366 shift = trie_base_shift<'T'>(); current = pointers[node + shift];
break;
370 return std::make_pair(current, shift);
398 MismatchTrie my_core;
410 return my_core.add(barcode_seq);
417 return my_core.length();
424 return my_core.size();
471 return search(search_seq, 0, 0, 0, max_mismatches);
476 const auto& pointers = my_core.pointers();
477 auto next = trie_next_base(seq[i], node, pointers);
478 auto current = next.first;
479 auto shift = next.second;
489 max_mismatches = mismatches;
490 return Result(current, mismatches);
495 if (mismatches <= max_mismatches) {
496 my_core.scan_final_position_with_mismatch(node, shift, alt, mismatches, max_mismatches);
499 return Result(alt, mismatches);
504 best =
search(seq, i, current, mismatches, max_mismatches);
508 if (mismatches <= max_mismatches) {
509 for (
int s = 0; s < NUM_BASES; ++s) {
514 auto alt = pointers[node + s];
519 if (mismatches <= max_mismatches) {
520 auto chosen =
search(seq, i, alt, mismatches, max_mismatches);
521 my_core.replace_best_with_chosen(best, best.index, best.mismatches, chosen, chosen.index, chosen.mismatches);
543template<
int num_segments_>
558 my_core(std::accumulate(segments.begin(), segments.end(), 0), duplicates),
559 my_boundaries(segments)
561 for (
int i = 1; i < num_segments_; ++i) {
562 my_boundaries[i] += my_boundaries[i-1];
567 MismatchTrie my_core;
568 std::array<SeqLength, num_segments_> my_boundaries;
580 return my_core.add(barcode_seq);
587 return my_core.length();
594 return my_core.size();
649 Result search(
const char* search_seq,
const std::array<int, num_segments_>& max_mismatches)
const {
650 int total_mismatches = std::accumulate(max_mismatches.begin(), max_mismatches.end(), 0);
651 return search(search_seq, 0, 0,
Result(), max_mismatches, total_mismatches);
655 Result
search(
const char* seq,
SeqLength i,
BarcodeIndex segment_id, Result state,
const std::array<int, num_segments_>& segment_mismatches,
int& total_mismatches)
const {
658 auto node = state.
index;
660 const auto& pointers = my_core.pointers();
661 auto next = trie_next_base(seq[i], node, pointers);
662 auto current = next.first;
663 auto shift = next.second;
673 total_mismatches = state.mismatches;
674 state.index = current;
680 auto& current_segment_mm = state.per_segment[segment_id];
681 ++current_segment_mm;
683 if (state.mismatches <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
684 my_core.scan_final_position_with_mismatch(node, shift, state.index, state.mismatches, total_mismatches);
690 auto next_segment_id = segment_id;
691 if (i == my_boundaries[segment_id]) {
697 best.mismatches = total_mismatches + 1;
700 state.index = current;
701 best =
search(seq, i, next_segment_id, state, segment_mismatches, total_mismatches);
705 auto& current_segment_mm = state.per_segment[segment_id];
706 ++current_segment_mm;
708 if (state.mismatches <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
709 for (
int s = 0; s < NUM_BASES; ++s) {
714 auto alt = pointers[node + s];
719 if (state.mismatches <= total_mismatches) {
721 auto chosen =
search(seq, i, next_segment_id, state, segment_mismatches, total_mismatches);
722 my_core.replace_best_with_chosen(best, best.index, best.mismatches, chosen, chosen.index, chosen.mismatches);
Search for barcodes with mismatches anywhere.
Definition MismatchTrie.hpp:383
BarcodeIndex size() const
Definition MismatchTrie.hpp:423
Result search(const char *search_seq, int max_mismatches) const
Definition MismatchTrie.hpp:470
void optimize()
Definition MismatchTrie.hpp:431
SeqLength length() const
Definition MismatchTrie.hpp:416
AnyMismatches(SeqLength barcode_length, DuplicateAction duplicates)
Definition MismatchTrie.hpp:395
TrieAddStatus add(const char *barcode_seq)
Definition MismatchTrie.hpp:409
Search for barcodes with segmented mismatches.
Definition MismatchTrie.hpp:544
BarcodeIndex size() const
Definition MismatchTrie.hpp:593
SeqLength length() const
Definition MismatchTrie.hpp:586
TrieAddStatus add(const char *barcode_seq)
Definition MismatchTrie.hpp:579
SegmentedMismatches()=default
SegmentedMismatches(std::array< SeqLength, num_segments_ > segments, DuplicateAction duplicates)
Definition MismatchTrie.hpp:557
void optimize()
Definition MismatchTrie.hpp:601
Result search(const char *search_seq, const std::array< int, num_segments_ > &max_mismatches) const
Definition MismatchTrie.hpp:649
Namespace for the kaori barcode-matching library.
Definition BarcodePool.hpp:16
std::size_t SeqLength
Definition utils.hpp:37
constexpr BarcodeIndex STATUS_AMBIGUOUS
Definition utils.hpp:53
constexpr BarcodeIndex STATUS_UNMATCHED
Definition utils.hpp:48
DuplicateAction
Definition utils.hpp:26
bool is_barcode_index_ok(BarcodeIndex index)
Definition utils.hpp:60
std::vector< const char * >::size_type BarcodeIndex
Definition utils.hpp:43
Results of search().
Definition MismatchTrie.hpp:439
BarcodeIndex index
Definition MismatchTrie.hpp:453
int mismatches
Definition MismatchTrie.hpp:460
Result of the segmented search.
Definition MismatchTrie.hpp:609
int mismatches
Definition MismatchTrie.hpp:632
BarcodeIndex index
Definition MismatchTrie.hpp:625
std::array< int, num_segments_ > per_segment
Definition MismatchTrie.hpp:639
Status of barcode sequence addition to the trie.
Definition MismatchTrie.hpp:26
bool duplicate_replaced
Definition MismatchTrie.hpp:41
bool duplicate_cleared
Definition MismatchTrie.hpp:47
bool has_ambiguous
Definition MismatchTrie.hpp:30
bool is_duplicate
Definition MismatchTrie.hpp:35
Utilities for sequence matching.