kaori
A C++ library for barcode extraction and matching
Loading...
Searching...
No Matches
MismatchTrie.hpp
Go to the documentation of this file.
1#ifndef KAORI_MISMATCH_TRIE_HPP
2#define KAORI_MISMATCH_TRIE_HPP
3
4#include <array>
5#include <vector>
6#include <stdexcept>
7#include <numeric>
8#include <cstddef>
9#include <limits>
10
11#include "utils.hpp"
12
19namespace kaori {
20
30 bool has_ambiguous = false;
31
35 bool is_duplicate = false;
36
41 bool duplicate_replaced = false;
42
47 bool duplicate_cleared = false;
48};
49
/**
 * Offset of a base within a trie node.
 *
 * Every trie node stores NUM_BASES consecutive entries, one per base, in the
 * order A = 0, C = 1, G = 2, T = 3.
 *
 * @tparam base_ One of 'A', 'C', 'G' or 'T' (uppercase only).
 * @return Offset of `base_` within a node.
 */
template<char base_>
constexpr int trie_base_shift() {
    if constexpr(base_ == 'A') {
        return 0;
    } else if constexpr(base_ == 'C') {
        return 1;
    } else if constexpr(base_ == 'G') {
        return 2;
    } else {
        // Reject anything else at compile time rather than silently treating it as 'T'.
        static_assert(base_ == 'T', "trie_base_shift() only accepts 'A', 'C', 'G' or 'T'");
        return 3;
    }
}
65
66class MismatchTrie {
67public:
68 MismatchTrie() = default;
69
70 MismatchTrie(SeqLength barcode_length, DuplicateAction duplicates) :
71 my_length(barcode_length),
72 my_duplicates(duplicates),
73 my_pointers(NUM_BASES, STATUS_UNMATCHED)
74 {}
75
76private:
77 SeqLength my_length;
78 DuplicateAction my_duplicates;
79 std::vector<BarcodeIndex> my_pointers;
80 BarcodeIndex my_counter = 0;
81
82 BarcodeIndex next(BarcodeIndex node) {
83 auto current = my_pointers[node]; // don't make this a reference as it gets invalidated by the resize.
84 if (current == STATUS_UNMATCHED) {
85 if (my_pointers.size() > std::numeric_limits<BarcodeIndex>::max()) { // this should never happen on 64-bit machines, but you never know.
86 throw std::runtime_error("integer overflow for trie nodes");
87 }
88 current = my_pointers.size();
89 my_pointers[node] = current;
90 my_pointers.insert(my_pointers.end(), NUM_BASES, STATUS_UNMATCHED); // this should throw a bad_alloc if we exceed the vector size limits.
91 }
92 return current;
93 }
94
95 void end(BarcodeIndex node, TrieAddStatus& status) {
96 auto& current = my_pointers[node];
97
98 if (current == STATUS_UNMATCHED) {
99 current = my_counter;
100 } else if (current == STATUS_AMBIGUOUS) {
101 status.is_duplicate = true;
102 } else {
103 status.is_duplicate = true;
104 switch(my_duplicates) {
105 case DuplicateAction::FIRST:
106 break;
107 case DuplicateAction::LAST:
108 status.duplicate_replaced = true;
109 current = my_counter;
110 break;
111 case DuplicateAction::NONE:
112 status.duplicate_cleared = true;
113 current = STATUS_AMBIGUOUS;
114 break;
115 case DuplicateAction::ERROR:
116 throw std::runtime_error("duplicate sequences detected (" +
117 std::to_string(current + 1) + ", " +
118 std::to_string(my_counter + 1) + ") when constructing the trie");
119 }
120 }
121 }
122
123 template<char base_>
124 void process_ambiguous(SeqLength i, BarcodeIndex node, const char* barcode_seq, TrieAddStatus& status) {
125 node += trie_base_shift<base_>();
126 ++i;
127 if (i == my_length) {
128 end(node, status);
129 } else {
130 node = next(node);
131 recursive_add(i, node, barcode_seq, status);
132 }
133 }
134
135 void recursive_add(SeqLength i, BarcodeIndex node, const char* barcode_seq, TrieAddStatus& status) {
136 // Processing a stretch of non-ambiguous codes, where possible.
137 // This reduces the recursion depth among the (hopefully fewer) ambiguous codes.
138 while (1) {
139 switch (barcode_seq[i]) {
140 case 'A': case 'a':
141 node += trie_base_shift<'A'>(); break;
142 case 'C': case 'c':
143 node += trie_base_shift<'C'>(); break;
144 case 'G': case 'g':
145 node += trie_base_shift<'G'>(); break;
146 case 'T': case 't':
147 node += trie_base_shift<'T'>(); break;
148 default:
149 goto ambiguous;
150 }
151 if ((++i) == my_length) {
152 end(node, status);
153 return;
154 } else {
155 node = next(node);
156 }
157 }
158
159ambiguous:
160 // Processing the ambiguous codes.
161 status.has_ambiguous = true;
162
163 auto processA = [&]() -> void { process_ambiguous<'A'>(i, node, barcode_seq, status); };
164 auto processC = [&]() -> void { process_ambiguous<'C'>(i, node, barcode_seq, status); };
165 auto processG = [&]() -> void { process_ambiguous<'G'>(i, node, barcode_seq, status); };
166 auto processT = [&]() -> void { process_ambiguous<'T'>(i, node, barcode_seq, status); };
167
168 switch(barcode_seq[i]) {
169 case 'R': case 'r':
170 processA(); processG(); break;
171 case 'Y': case 'y':
172 processC(); processT(); break;
173 case 'S': case 's':
174 processC(); processG(); break;
175 case 'W': case 'w':
176 processA(); processT(); break;
177 case 'K': case 'k':
178 processG(); processT(); break;
179 case 'M': case 'm':
180 processA(); processC(); break;
181 case 'B': case 'b':
182 processC(); processG(); processT(); break;
183 case 'D': case 'd':
184 processA(); processG(); processT(); break;
185 case 'H': case 'h':
186 processA(); processC(); processT(); break;
187 case 'V': case 'v':
188 processA(); processC(); processG(); break;
189 case 'N': case 'n':
190 processA(); processC(); processG(); processT(); break;
191 default:
192 throw std::runtime_error("unknown base '" + std::string(1, barcode_seq[i]) + "' detected when constructing the trie");
193 }
194 }
195
196public:
197 TrieAddStatus add(const char* barcode_seq) {
198 TrieAddStatus status;
199 recursive_add(0, 0, barcode_seq, status);
200 ++my_counter;
201 return status;
202 }
203
204 SeqLength length() const {
205 return my_length;
206 }
207
208 BarcodeIndex size() const {
209 return my_counter;
210 }
211
212 const std::vector<BarcodeIndex>& pointers() const {
213 return my_pointers;
214 }
215
216public:
217 // To be called in the middle steps of the recursive search (i.e., for all but the last position).
218 template<class SearchResult_>
219 void replace_best_with_chosen(SearchResult_& best, BarcodeIndex& best_index, int best_score, const SearchResult_& chosen, BarcodeIndex chosen_index, int chosen_score) const {
220 if (is_barcode_index_ok(chosen_index)) {
221 if (chosen_score < best_score) {
222 best = chosen;
223 } else if (chosen_score == best_score) {
224 if (chosen_index != best_index) { // protect against multiple occurrences of IUPAC code-containing barcodes.
225 if (my_duplicates == DuplicateAction::FIRST) {
226 if (chosen_index < best_index) {
227 best_index = chosen_index;
228 }
229 } else if (my_duplicates == DuplicateAction::LAST) {
230 if (chosen_index > best_index) {
231 best_index = chosen_index;
232 }
233 } else {
234 best_index = STATUS_AMBIGUOUS;
235 }
236 }
237 }
238
239 } else if (chosen_index == STATUS_AMBIGUOUS) {
240 if (chosen_score < best_score) {
241 best = chosen;
242 } else if (chosen_score == best_score) {
243 // Ambiguity is infectious. Each ambiguous status indicates that there
244 // are already 2+ barcodes on this score, so it doesn't matter how
245 // many other unambiguous barcodes are here; we're already ambiguous.
246 best_index = STATUS_AMBIGUOUS;
247 }
248 }
249 }
250
251 // To be called in the last step of the recursive search.
252 void scan_final_position_with_mismatch(BarcodeIndex node, int refshift, BarcodeIndex& current_index, int current_mismatches, int& mismatch_cap) const {
253 bool found = false;
254 for (int s = 0; s < NUM_BASES; ++s) {
255 if (s == refshift) {
256 continue;
257 }
258
259 auto candidate = my_pointers[node + s];
260 if (is_barcode_index_ok(candidate)) {
261 if (found) {
262 if (candidate != current_index) { // protect against multiple occurrences of IUPAC-containg barcodes.
263 if (my_duplicates == DuplicateAction::FIRST) {
264 if (current_index > candidate) {
265 current_index = candidate;
266 }
267 } else if (my_duplicates == DuplicateAction::LAST) {
268 if (current_index < candidate) {
269 current_index = candidate;
270 }
271 } else {
272 current_index = STATUS_AMBIGUOUS; // ambiguous, so we quit early.
273 break;
274 }
275 }
276 } else {
277 current_index = candidate;
278 mismatch_cap = current_mismatches;
279 found = true;
280 }
281
282 } else if (candidate == STATUS_AMBIGUOUS) {
283 // If an ambiguity is present on a base at the last position,
284 // and we're accepting a mismatch on the last position, then
285 // we already have at least two known barcodes that match the
286 // input sequence. The behavior of the other bases is irrelevant;
287 // even if they refer to non-ambiguous barcodes, that just adds to the
288 // set of 2+ barcodes that the input sequence already matches.
289 // So, we have no choice but to fail the match due to ambiguity.
290 current_index = STATUS_AMBIGUOUS;
291 mismatch_cap = current_mismatches;
292 break;
293 }
294 }
295 }
296
297public:
298 void optimize() {
299 BarcodeIndex maxed = 0;
300 if (!is_optimal(0, 0, maxed)) {
301 std::vector<BarcodeIndex> replacement;
302 replacement.reserve(my_pointers.size());
303 optimize(0, 0, replacement);
304 my_pointers.swap(replacement);
305 }
306 }
307
308private:
309 // Optimization involves reorganizing the nodes so that the pointers are
310 // always increasing. This promotes memory locality of similar sequences
311 // in a depth-first search (which is what search() does anyway).
312 bool is_optimal(SeqLength i, BarcodeIndex node, BarcodeIndex& maxed) const {
313 ++i;
314 if (i < my_length) {
315 for (int s = 0; s < NUM_BASES; ++s) {
316 auto v = my_pointers[node + s];
317 if (!is_barcode_index_ok(v)) {
318 continue;
319 }
320
321 if (v < maxed) {
322 return false;
323 }
324
325 maxed = v;
326 if (!is_optimal(i, v, maxed)) {
327 return false;
328 }
329 }
330 }
331 return true;
332 }
333
334 void optimize(SeqLength i, BarcodeIndex node, std::vector<BarcodeIndex>& trie) const {
335 auto it = my_pointers.begin() + node;
336 BarcodeIndex new_node = trie.size();
337 trie.insert(trie.end(), it, it + NUM_BASES);
338
339 ++i;
340 if (i < my_length) {
341 for (int s = 0; s < NUM_BASES; ++s) {
342 auto& v = trie[new_node + s];
343 if (!is_barcode_index_ok(v)) {
344 continue;
345 }
346
347 auto original = v;
348 v = trie.size();
349 optimize(i, original, trie);
350 }
351 }
352 }
353};
354
355inline std::pair<BarcodeIndex, int> trie_next_base(char base, BarcodeIndex node, const std::vector<BarcodeIndex>& pointers) {
356 BarcodeIndex current;
357 int shift;
358 switch (base) {
359 case 'A': case 'a':
360 shift = trie_base_shift<'A'>(); current = pointers[node + shift]; break;
361 case 'C': case 'c':
362 shift = trie_base_shift<'C'>(); current = pointers[node + shift]; break;
363 case 'G': case 'g':
364 shift = trie_base_shift<'G'>(); current = pointers[node + shift]; break;
365 case 'T': case 't':
366 shift = trie_base_shift<'T'>(); current = pointers[node + shift]; break;
367 default:
368 shift = -1; current = STATUS_UNMATCHED; break;
369 }
370 return std::make_pair(current, shift);
371}
384public:
/**
 * Default constructor.
 * NOTE(review): presumably a placeholder to be assigned over before use.
 */
389 AnyMismatches() = default;
390
/**
 * @param barcode_length Length of the barcode sequences to be added.
 * @param duplicates How to handle multiple barcodes mapping to the same sequence,
 * forwarded to the underlying MismatchTrie.
 */
395 AnyMismatches(SeqLength barcode_length, DuplicateAction duplicates) : my_core(barcode_length, duplicates) {}
396
397private:
// Trie holding all added barcodes.
398 MismatchTrie my_core;
399
400public:
/**
 * Adds a barcode sequence, forwarding to MismatchTrie::add().
 * Barcodes receive consecutive indices in the order they are added.
 *
 * @param barcode_seq Pointer to the barcode sequence; the first length() characters
 * are read. IUPAC ambiguity codes are expanded into all bases they represent.
 * @return Status of the addition, e.g., whether ambiguity codes or duplicates were seen.
 * @throws std::runtime_error If an unrecognized base is present, or if a duplicate is
 * found when constructed with DuplicateAction::ERROR.
 */
409 TrieAddStatus add(const char* barcode_seq) {
410 return my_core.add(barcode_seq);
411 }
412
417 return my_core.length();
418 }
419
424 return my_core.size();
425 }
426
/**
 * Rewrites the trie's node storage into depth-first order to improve memory
 * locality for search(). Does nothing if the storage is already in that order.
 */
431 void optimize() {
432 my_core.optimize();
433 }
434
435public:
462
/**
 * Searches for the barcode that best matches a query sequence.
 *
 * @param search_seq Pointer to the query sequence; presumably at least length() characters long.
 * @param max_mismatches Maximum number of mismatches to allow.
 * @return Result of the search. Its index is STATUS_UNMATCHED if no barcode is within
 * max_mismatches, or STATUS_AMBIGUOUS if different barcodes tie for the fewest mismatches
 * and the DuplicateAction used at construction does not pick one.
 */
470 Result search(const char* search_seq, int max_mismatches) const {
// Start at position 0 of the root node (offset 0) with no mismatches so far.
471 return search(search_seq, 0, 0, 0, max_mismatches);
472 }
473
474private:
475 Result search(const char* seq, SeqLength i, BarcodeIndex node, int mismatches, int& max_mismatches) const {
476 const auto& pointers = my_core.pointers();
477 auto next = trie_next_base(seq[i], node, pointers);
478 auto current = next.first;
479 auto shift = next.second;
480
481 // At the end: we prepare to return the actual values. We also refine
482 // the max number of mismatches so that we don't search for things with
483 // more mismatches than the best hit that was already encountered.
484 SeqLength length = my_core.length();
485 ++i;
486
487 if (i == length) {
488 if (is_barcode_index_ok(current) || current == STATUS_AMBIGUOUS) {
489 max_mismatches = mismatches; // this assignment should always decrease max_mismatches, otherwise the search would have terminated earlier.
490 return Result(current, mismatches);
491 }
492
494 ++mismatches;
495 if (mismatches <= max_mismatches) {
496 my_core.scan_final_position_with_mismatch(node, shift, alt, mismatches, max_mismatches);
497 }
498
499 return Result(alt, mismatches);
500
501 } else {
502 Result best(STATUS_UNMATCHED, max_mismatches + 1);
503 if (is_barcode_index_ok(current)) {
504 best = search(seq, i, current, mismatches, max_mismatches);
505 }
506
507 ++mismatches;
508 if (mismatches <= max_mismatches) {
509 for (int s = 0; s < NUM_BASES; ++s) {
510 if (shift == s) {
511 continue;
512 }
513
514 auto alt = pointers[node + s];
515 if (!is_barcode_index_ok(alt)) {
516 continue;
517 }
518
519 if (mismatches <= max_mismatches) { // check again, just in case max_mismatches changed.
520 auto chosen = search(seq, i, alt, mismatches, max_mismatches);
521 my_core.replace_best_with_chosen(best, best.index, best.mismatches, chosen, chosen.index, chosen.mismatches);
522 }
523 }
524 }
525
526 return best;
527 }
528 }
529};
530
543template<int num_segments_>
545public:
551
557 SegmentedMismatches(std::array<SeqLength, num_segments_> segments, DuplicateAction duplicates) :
558 my_core(std::accumulate(segments.begin(), segments.end(), 0), duplicates),
559 my_boundaries(segments)
560 {
561 for (int i = 1; i < num_segments_; ++i) {
562 my_boundaries[i] += my_boundaries[i-1];
563 }
564 }
565
566private:
567 MismatchTrie my_core;
568 std::array<SeqLength, num_segments_> my_boundaries;
569
570public:
579 TrieAddStatus add(const char* barcode_seq) {
580 return my_core.add(barcode_seq);
581 }
582
587 return my_core.length();
588 }
589
594 return my_core.size();
595 }
596
/**
 * Rewrites the trie's node storage into depth-first order to improve memory
 * locality for search(). Does nothing if the storage is already in that order.
 */
601 void optimize() {
602 my_core.optimize();
603 }
604
605public:
/**
 * Result of the segmented search.
 */
609 struct Result {
// Value-initializes per_segment to zeros; index and mismatches use their member initializers.
613 Result() : per_segment() {}
// Index of the best-matching barcode, or STATUS_UNMATCHED / STATUS_AMBIGUOUS.
// During search(), this doubles as the offset of the current trie node.
625 BarcodeIndex index = 0; // Both index and mismatches must start at zero: search() begins from a default-constructed Result, where index 0 is the trie root and no mismatches have been counted yet.
626
// Total number of mismatches across all segments.
632 int mismatches = 0;
633
// Number of mismatches in each segment.
639 std::array<int, num_segments_> per_segment;
640 };
641
649 Result search(const char* search_seq, const std::array<int, num_segments_>& max_mismatches) const {
650 int total_mismatches = std::accumulate(max_mismatches.begin(), max_mismatches.end(), 0);
651 return search(search_seq, 0, 0, Result(), max_mismatches, total_mismatches);
652 }
653
654private:
655 Result search(const char* seq, SeqLength i, BarcodeIndex segment_id, Result state, const std::array<int, num_segments_>& segment_mismatches, int& total_mismatches) const {
656 // Note that, during recursion, state.index does double duty
657 // as the index of the node on the trie.
658 auto node = state.index;
659
660 const auto& pointers = my_core.pointers();
661 auto next = trie_next_base(seq[i], node, pointers);
662 auto current = next.first;
663 auto shift = next.second;
664
665 // At the end: we prepare to return the actual values. We also refine
666 // the max number of mismatches so that we don't search for things with
667 // more mismatches than the best hit that was already encountered.
668 SeqLength length = my_core.length();
669 ++i;
670
671 if (i == length) {
672 if (is_barcode_index_ok(current) || current == STATUS_AMBIGUOUS) {
673 total_mismatches = state.mismatches; // this assignment should always decrease total_mismatches, otherwise the search would have terminated earlier.
674 state.index = current;
675 return state;
676 }
677
678 state.index = STATUS_UNMATCHED;
679 ++state.mismatches;
680 auto& current_segment_mm = state.per_segment[segment_id];
681 ++current_segment_mm;
682
683 if (state.mismatches <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
684 my_core.scan_final_position_with_mismatch(node, shift, state.index, state.mismatches, total_mismatches);
685 }
686
687 return state;
688
689 } else {
690 auto next_segment_id = segment_id;
691 if (i == my_boundaries[segment_id]) {
692 ++next_segment_id;
693 }
694
695 Result best;
696 best.index = STATUS_UNMATCHED;
697 best.mismatches = total_mismatches + 1;
698
699 if (is_barcode_index_ok(current)) {
700 state.index = current;
701 best = search(seq, i, next_segment_id, state, segment_mismatches, total_mismatches);
702 }
703
704 ++state.mismatches;
705 auto& current_segment_mm = state.per_segment[segment_id];
706 ++current_segment_mm;
707
708 if (state.mismatches <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
709 for (int s = 0; s < NUM_BASES; ++s) {
710 if (shift == s) {
711 continue;
712 }
713
714 auto alt = pointers[node + s];
715 if (!is_barcode_index_ok(alt)) {
716 continue;
717 }
718
719 if (state.mismatches <= total_mismatches) { // check again, just in case total_mismatches changed.
720 state.index = alt;
721 auto chosen = search(seq, i, next_segment_id, state, segment_mismatches, total_mismatches);
722 my_core.replace_best_with_chosen(best, best.index, best.mismatches, chosen, chosen.index, chosen.mismatches);
723 }
724 }
725 }
726
727 return best;
728 }
729 }
730};
731
732}
733
734#endif
Search for barcodes with mismatches anywhere.
Definition MismatchTrie.hpp:383
BarcodeIndex size() const
Definition MismatchTrie.hpp:423
Result search(const char *search_seq, int max_mismatches) const
Definition MismatchTrie.hpp:470
void optimize()
Definition MismatchTrie.hpp:431
SeqLength length() const
Definition MismatchTrie.hpp:416
AnyMismatches(SeqLength barcode_length, DuplicateAction duplicates)
Definition MismatchTrie.hpp:395
TrieAddStatus add(const char *barcode_seq)
Definition MismatchTrie.hpp:409
Search for barcodes with segmented mismatches.
Definition MismatchTrie.hpp:544
BarcodeIndex size() const
Definition MismatchTrie.hpp:593
SeqLength length() const
Definition MismatchTrie.hpp:586
TrieAddStatus add(const char *barcode_seq)
Definition MismatchTrie.hpp:579
SegmentedMismatches(std::array< SeqLength, num_segments_ > segments, DuplicateAction duplicates)
Definition MismatchTrie.hpp:557
void optimize()
Definition MismatchTrie.hpp:601
Result search(const char *search_seq, const std::array< int, num_segments_ > &max_mismatches) const
Definition MismatchTrie.hpp:649
Namespace for the kaori barcode-matching library.
Definition BarcodePool.hpp:16
std::size_t SeqLength
Definition utils.hpp:37
constexpr BarcodeIndex STATUS_AMBIGUOUS
Definition utils.hpp:53
constexpr BarcodeIndex STATUS_UNMATCHED
Definition utils.hpp:48
DuplicateAction
Definition utils.hpp:26
bool is_barcode_index_ok(BarcodeIndex index)
Definition utils.hpp:60
std::vector< const char * >::size_type BarcodeIndex
Definition utils.hpp:43
Results of search().
Definition MismatchTrie.hpp:439
BarcodeIndex index
Definition MismatchTrie.hpp:453
int mismatches
Definition MismatchTrie.hpp:460
Result of the segmented search.
Definition MismatchTrie.hpp:609
int mismatches
Definition MismatchTrie.hpp:632
BarcodeIndex index
Definition MismatchTrie.hpp:625
std::array< int, num_segments_ > per_segment
Definition MismatchTrie.hpp:639
Status of barcode sequence addition to the trie.
Definition MismatchTrie.hpp:26
bool duplicate_replaced
Definition MismatchTrie.hpp:41
bool duplicate_cleared
Definition MismatchTrie.hpp:47
bool has_ambiguous
Definition MismatchTrie.hpp:30
bool is_duplicate
Definition MismatchTrie.hpp:35
Utilities for sequence matching.