kaori
A C++ library for barcode extraction and matching
Loading...
Searching...
No Matches
MismatchTrie.hpp
Go to the documentation of this file.
1#ifndef KAORI_MISMATCH_TRIE_HPP
2#define KAORI_MISMATCH_TRIE_HPP
3
4#include <array>
5#include <vector>
6#include <stdexcept>
7#include <numeric>
8#include "utils.hpp"
9
16namespace kaori {
17
26public:
// Value held by an empty slot in the trie; the search methods also report it
// as the index when no barcode could be matched.
30 static constexpr int STATUS_MISSING = -1;
31
// Value stored in a leaf once duplicated barcodes have been cleared
// (DuplicateAction::NONE); the search methods also report it as the index
// when several different barcodes tie for the best score.
35 static constexpr int STATUS_AMBIGUOUS = -2;
36
37protected:
// Number of slots per trie node, one for each of A, C, G and T.
41 static constexpr int NUM_BASES = 4;
46public:
52
57 MismatchTrie(size_t barcode_length, DuplicateAction duplicates) :
58 length(barcode_length),
59 pointers(NUM_BASES, STATUS_MISSING),
60 duplicates(duplicates),
61 counter(0)
62 {}
63
64public:
/**
 * Flags reported by add() describing what happened to the inserted barcode.
 * Every flag starts out as false.
 */
struct AddStatus {
    /** The barcode contained at least one base other than A, C, G or T (i.e., an IUPAC ambiguity code). */
    bool has_ambiguous{false};

    /** The sequence (or one of its expansions) was already present in the trie. */
    bool is_duplicate{false};

    /** The previously stored barcode was overwritten by this one (DuplicateAction::LAST). */
    bool duplicate_replaced{false};

    /** The previously stored barcode was replaced by the ambiguous marker (DuplicateAction::NONE). */
    bool duplicate_cleared{false};
};
91
92private:
93 void next(int shift, int& position) {
94 auto& current = pointers[position + shift];
95 if (current < 0) {
96 current = pointers.size();
97 position = current;
98 pointers.resize(position + NUM_BASES, STATUS_MISSING);
99 } else {
100 position = current;
101 }
102 }
103
104 void end(int shift, int position, AddStatus& status) {
105 auto& current = pointers[position + shift];
106 if (current >= 0) {
107 status.is_duplicate = true;
108 switch(duplicates) {
109 case DuplicateAction::FIRST:
110 break;
111 case DuplicateAction::LAST:
112 status.duplicate_replaced = true;
113 current = counter;
114 break;
115 case DuplicateAction::NONE:
116 status.duplicate_cleared = true;
117 current = STATUS_AMBIGUOUS;
118 break;
119 case DuplicateAction::ERROR:
120 throw std::runtime_error("duplicate sequences detected (" +
121 std::to_string(current + 1) + ", " +
122 std::to_string(counter + 1) + ") when constructing the trie");
123 }
124
125 } else if (current == STATUS_MISSING) {
126 current = counter;
127 } else if (current == STATUS_AMBIGUOUS) {
128 status.is_duplicate = true;
129 }
130 }
131
132 void recursive_add(size_t i, int position, const char* barcode_seq, AddStatus& status) {
133 // Processing a stretch of non-ambiguous codes, where possible.
134 // This reduces the recursion depth among the (hopefully fewer) ambiguous codes.
135 while (1) {
136 auto shift = base_shift<true>(barcode_seq[i]);
137 if (shift == NON_STANDARD_BASE) {
138 break;
139 }
140
141 if ((++i) == length) {
142 end(shift, position, status);
143 return;
144 } else {
145 next(shift, position);
146 }
147 }
148
149 // Processing the ambiguous codes.
150 status.has_ambiguous = true;
151
152 auto process = [&](char base) -> void {
153 auto shift = base_shift(base);
154 if (i + 1 == length) {
155 end(shift, position, status);
156 } else {
157 auto curpos = position;
158 next(shift, curpos);
159 recursive_add(i + 1, curpos, barcode_seq, status);
160 }
161 };
162
163 switch(barcode_seq[i]) {
164 case 'R': case 'r':
165 process('A'); process('G'); break;
166 case 'Y': case 'y':
167 process('C'); process('T'); break;
168 case 'S': case 's':
169 process('C'); process('G'); break;
170 case 'W': case 'w':
171 process('A'); process('T'); break;
172 case 'K': case 'k':
173 process('G'); process('T'); break;
174 case 'M': case 'm':
175 process('A'); process('C'); break;
176 case 'B': case 'b':
177 process('C'); process('G'); process('T'); break;
178 case 'D': case 'd':
179 process('A'); process('G'); process('T'); break;
180 case 'H': case 'h':
181 process('A'); process('C'); process('T'); break;
182 case 'V': case 'v':
183 process('A'); process('C'); process('G'); break;
184 case 'N': case 'n':
185 process('A'); process('C'); process('G'); process('T'); break;
186 default:
187 throw std::runtime_error("unknown base '" + std::string(1, barcode_seq[i]) + "' detected when constructing the trie");
188 }
189 }
190
191public:
200 AddStatus add(const char* barcode_seq) {
201 AddStatus status;
202 recursive_add(0, 0, barcode_seq, status);
203 ++counter;
204 return status;
205 }
206
207public:
211 size_t get_length() const {
212 return length;
213 }
214
218 int size() const {
219 return counter;
220 }
221
222protected:
226 size_t length;
227 std::vector<int> pointers;
228
// Returned by base_shift<true>() for anything other than A, C, G or T.
static constexpr int NON_STANDARD_BASE = -1;

/**
 * Convert a base into its offset within a trie node:
 * A = 0, C = 1, G = 2 and T = 3, ignoring case.
 *
 * @tparam allow_unknown Whether to return NON_STANDARD_BASE for any other
 * character instead of throwing.
 * @param base Character to convert.
 * @return Offset of the base, or NON_STANDARD_BASE if `allow_unknown` is true
 * and `base` is not a standard base.
 *
 * @throws std::runtime_error if `allow_unknown` is false and `base` is not a standard base.
 */
template<bool allow_unknown = false>
static int base_shift(char base) {
    switch (base) {
        case 'A': case 'a':
            return 0;
        case 'C': case 'c':
            return 1;
        case 'G': case 'g':
            return 2;
        case 'T': case 't':
            return 3;
        default:
            break;
    }

    if constexpr(allow_unknown) {
        return NON_STANDARD_BASE;
    } else {
        throw std::runtime_error("unknown base '" + std::string(1, base) + "' detected when constructing the trie");
    }
}
259protected:
263 DuplicateAction duplicates;
264
265 // To be called in the middle steps of the recursive search (i.e., for all but the last position).
266 template<class SearchResult_>
267 void replace_best_with_chosen(SearchResult_& best, int& best_index, int best_score, const SearchResult_& chosen, int chosen_index, int chosen_score) const {
268 if (chosen_index >= 0) {
269 if (chosen_score < best_score) {
270 best = chosen;
271 } else if (chosen_score == best_score) {
272 if (chosen_index != best_index) { // protect against multiple occurrences of IUPAC code-containing barcodes.
273 if (duplicates == DuplicateAction::FIRST) {
274 if (chosen_index < best_index) {
275 best_index = chosen_index;
276 }
277 } else if (duplicates == DuplicateAction::LAST) {
278 if (chosen_index > best_index) {
279 best_index = chosen_index;
280 }
281 } else {
282 best_index = STATUS_AMBIGUOUS;
283 }
284 }
285 }
286
287 } else if (chosen_index == STATUS_AMBIGUOUS) {
288 if (chosen_score < best_score) {
289 best = chosen;
290 } else if (chosen_score == best_score) {
291 // Ambiguity is infectious. Each ambiguous status indicates that there
292 // are already 2+ barcodes on this score, so it doesn't matter how
293 // many other unambiguous barcodes are here; we're already ambiguous.
294 best_index = STATUS_AMBIGUOUS;
295 }
296 }
297 }
298
299 // To be called in the last step of the recursive search.
300 void scan_final_position_with_mismatch(int node, int refshift, int& current_index, int current_mismatches, int& mismatch_cap) const {
301 bool found = false;
302 for (int s = 0; s < NUM_BASES; ++s) {
303 if (s == refshift) {
304 continue;
305 }
306
307 int candidate = pointers[node + s];
308 if (candidate >= 0) {
309 if (found) {
310 if (candidate != current_index) { // protect against multiple occurrences of IUPAC-containg barcodes.
311 if (duplicates == DuplicateAction::FIRST) {
312 if (current_index > candidate) {
313 current_index = candidate;
314 }
315 } else if (duplicates == DuplicateAction::LAST) {
316 if (current_index < candidate) {
317 current_index = candidate;
318 }
319 } else {
320 current_index = STATUS_AMBIGUOUS; // ambiguous, so we quit early.
321 break;
322 }
323 }
324 } else {
325 current_index = candidate;
326 mismatch_cap = current_mismatches;
327 found = true;
328 }
329
330 } else if (candidate == STATUS_AMBIGUOUS) {
331 // If an ambiguity is present on a base at the last position,
332 // and we're accepting a mismatch on the last position, then
333 // we already have at least two known barcodes that match the
334 // input sequence. The behavior of the other bases is irrelevant;
335 // even if they refer to non-ambiguous barcodes, that just adds to the
336 // set of 2+ barcodes that the input sequence already matches.
337 // So, we have no choice but to fail the match due to ambiguity.
338 current_index = STATUS_AMBIGUOUS;
339 mismatch_cap = current_mismatches;
340 break;
341 }
342 }
343 }
348private:
349 int counter;
350
351public:
356 void optimize() {
357 int maxed = 0;
358 if (!is_optimal(0, 0, maxed)) {
359 std::vector<int> replacement;
360 replacement.reserve(pointers.size());
361 optimize(0, 0, replacement);
362 pointers.swap(replacement);
363 }
364 }
365
366private:
367 // Optimization involves reorganizing the nodes so that the pointers are
368 // always increasing. This promotes memory locality of similar sequences
369 // in a depth-first search (which is what search() does anyway).
370 bool is_optimal(int node, size_t pos, int& maxed) const {
371 ++pos;
372 if (pos < length) {
373 for (int s = 0; s < NUM_BASES; ++s) {
374 auto v = pointers[node + s];
375 if (v < 0) {
376 continue;
377 }
378
379 if (v < maxed) {
380 return false;
381 }
382
383 maxed = v;
384 if (!is_optimal(v, pos, maxed)) {
385 return false;
386 }
387 }
388 }
389 return true;
390 }
391
392 void optimize(int node, size_t pos, std::vector<int>& trie) const {
393 auto it = pointers.begin() + node;
394 size_t new_node = trie.size();
395 trie.insert(trie.end(), it, it + NUM_BASES);
396
397 ++pos;
398 if (pos < length) {
399 for (int s = 0; s < NUM_BASES; ++s) {
400 auto& v = trie[new_node + s];
401 if (v < 0) {
402 continue;
403 }
404
405 auto original = v;
406 v = trie.size();
407 optimize(original, pos, trie);
408 }
409 }
410 }
411};
412
420public:
426
431 AnyMismatches(size_t barcode_length, DuplicateAction duplicates) : MismatchTrie(barcode_length, duplicates) {}
432
433public:
446 std::pair<int, int> search(const char* search_seq, int max_mismatches) const {
447 return search(search_seq, 0, 0, 0, max_mismatches);
448 }
449
450private:
451 std::pair<int, int> search(const char* seq, size_t pos, int node, int mismatches, int& max_mismatches) const {
452 int shift = base_shift<true>(seq[pos]);
453 int current = (shift >= 0 ? pointers[node + shift] : STATUS_MISSING);
454
455 // At the end: we prepare to return the actual values. We also refine
456 // the max number of mismatches so that we don't search for things with
457 // more mismatches than the best hit that was already encountered.
458 if (pos + 1 == length) {
459 if (current >= 0 || current == STATUS_AMBIGUOUS) {
460 max_mismatches = mismatches; // this assignment should always decrease max_mismatches, otherwise the search would have terminated earlier.
461 return std::make_pair(current, mismatches);
462 }
463
464 int alt = STATUS_MISSING;
465 ++mismatches;
466 if (mismatches <= max_mismatches) {
467 scan_final_position_with_mismatch(node, shift, alt, mismatches, max_mismatches);
468 }
469
470 return std::make_pair(alt, mismatches);
471
472 } else {
473 ++pos;
474
475 std::pair<int, int> best(STATUS_MISSING, max_mismatches + 1);
476 if (current >= 0) {
477 best = search(seq, pos, current, mismatches, max_mismatches);
478 }
479
480 ++mismatches;
481 if (mismatches <= max_mismatches) {
482 for (int s = 0; s < NUM_BASES; ++s) {
483 if (shift == s) {
484 continue;
485 }
486
487 int alt = pointers[node + s];
488 if (alt < 0) {
489 continue;
490 }
491
492 if (mismatches <= max_mismatches) { // check again, just in case max_mismatches changed.
493 auto chosen = search(seq, pos, alt, mismatches, max_mismatches);
494 replace_best_with_chosen(best, best.first, best.second, chosen, chosen.first, chosen.second);
495 }
496 }
497 }
498
499 return best;
500 }
501 }
502};
503
513template<size_t num_segments>
515public:
521
527 SegmentedMismatches(std::array<int, num_segments> segments, DuplicateAction duplicates) :
528 MismatchTrie(std::accumulate(segments.begin(), segments.end(), 0), duplicates),
529 boundaries(segments)
530 {
531 for (size_t i = 1; i < num_segments; ++i) {
532 boundaries[i] += boundaries[i-1];
533 }
534 }
535
536public:
540 struct Result {
544 Result() : per_segment() {}
553 int index = 0;
554
558 int total = 0;
559
563 std::array<int, num_segments> per_segment;
564 };
565
577 Result search(const char* search_seq, const std::array<int, num_segments>& max_mismatches) const {
578 int total_mismatches = std::accumulate(max_mismatches.begin(), max_mismatches.end(), 0);
579 return search(search_seq, 0, 0, Result(), max_mismatches, total_mismatches);
580 }
581
582private:
583 Result search(
584 const char* seq,
585 size_t pos,
586 size_t segment_id,
587 Result state,
588 const std::array<int, num_segments>& segment_mismatches,
589 int& total_mismatches
590 ) const {
591 // Note that, during recursion, state.index does double duty
592 // as the index of the node on the trie.
593 int node = state.index;
594
595 int shift = base_shift<true>(seq[pos]);
596 int current = (shift >= 0 ? pointers[node + shift] : STATUS_MISSING);
597
598 // At the end: we prepare to return the actual values. We also refine
599 // the max number of mismatches so that we don't search for things with
600 // more mismatches than the best hit that was already encountered.
601 if (pos + 1 == length) {
602 if (current >= 0 || current == STATUS_AMBIGUOUS) {
603 total_mismatches = state.total; // this assignment should always decrease total_mismatches, otherwise the search would have terminated earlier.
604 state.index = current;
605 return state;
606 }
607
608 state.index = STATUS_MISSING;
609 ++state.total;
610 auto& current_segment_mm = state.per_segment[segment_id];
611 ++current_segment_mm;
612
613 if (state.total <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
614 scan_final_position_with_mismatch(node, shift, state.index, state.total, total_mismatches);
615 }
616
617 return state;
618
619 } else {
620 auto next_pos = pos + 1;
621 auto next_segment_id = segment_id;
622 if (static_cast<int>(next_pos) == boundaries[segment_id]) { // TODO: boundaries should probably be size_t's, thus avoiding the need for this cast.
623 ++next_segment_id;
624 }
625
626 Result best;
627 best.index = STATUS_MISSING;
628 best.total = total_mismatches + 1;
629
630 if (current >= 0) {
631 state.index = current;
632 best = search(seq, next_pos, next_segment_id, state, segment_mismatches, total_mismatches);
633 }
634
635 ++state.total;
636 auto& current_segment_mm = state.per_segment[segment_id];
637 ++current_segment_mm;
638
639 if (state.total <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
640 for (int s = 0; s < NUM_BASES; ++s) {
641 if (shift == s) {
642 continue;
643 }
644
645 int alt = pointers[node + s];
646 if (alt < 0) {
647 continue;
648 }
649
650 if (state.total <= total_mismatches) { // check again, just in case total_mismatches changed.
651 state.index = alt;
652 auto chosen = search(seq, next_pos, next_segment_id, state, segment_mismatches, total_mismatches);
653 replace_best_with_chosen(best, best.index, best.total, chosen, chosen.index, chosen.total);
654 }
655 }
656 }
657
658 return best;
659 }
660 }
661private:
662 std::array<int, num_segments> boundaries;
663};
664
665}
666
667#endif
Search for barcodes with mismatches anywhere.
Definition MismatchTrie.hpp:419
AnyMismatches(size_t barcode_length, DuplicateAction duplicates)
Definition MismatchTrie.hpp:431
AnyMismatches()
Definition MismatchTrie.hpp:425
std::pair< int, int > search(const char *search_seq, int max_mismatches) const
Definition MismatchTrie.hpp:446
Base class for the mismatch search.
Definition MismatchTrie.hpp:25
void optimize()
Definition MismatchTrie.hpp:356
int size() const
Definition MismatchTrie.hpp:218
size_t get_length() const
Definition MismatchTrie.hpp:211
MismatchTrie()
Definition MismatchTrie.hpp:51
static constexpr int STATUS_MISSING
Definition MismatchTrie.hpp:30
AddStatus add(const char *barcode_seq)
Definition MismatchTrie.hpp:200
MismatchTrie(size_t barcode_length, DuplicateAction duplicates)
Definition MismatchTrie.hpp:57
static constexpr int STATUS_AMBIGUOUS
Definition MismatchTrie.hpp:35
Search for barcodes with segmented mismatches.
Definition MismatchTrie.hpp:514
SegmentedMismatches(std::array< int, num_segments > segments, DuplicateAction duplicates)
Definition MismatchTrie.hpp:527
SegmentedMismatches()
Definition MismatchTrie.hpp:520
Result search(const char *search_seq, const std::array< int, num_segments > &max_mismatches) const
Definition MismatchTrie.hpp:577
Namespace for the kaori barcode-matching library.
Definition BarcodePool.hpp:13
Status of the barcode sequence addition.
Definition MismatchTrie.hpp:68
bool duplicate_cleared
Definition MismatchTrie.hpp:89
bool duplicate_replaced
Definition MismatchTrie.hpp:83
bool has_ambiguous
Definition MismatchTrie.hpp:72
bool is_duplicate
Definition MismatchTrie.hpp:77
Result of the segmented search.
Definition MismatchTrie.hpp:540
std::array< int, num_segments > per_segment
Definition MismatchTrie.hpp:563
int index
Definition MismatchTrie.hpp:553
int total
Definition MismatchTrie.hpp:558