// Number of child slots per trie node: loops in this file visit a node's
// children as pointers[node + s] for s in [0, NUM_BASES).
41 static constexpr int NUM_BASES = 4;
58 length(barcode_length),
60 duplicates(duplicates),
// Descend from the node at `position` through the child slot selected by `shift`.
// Only part of the body is visible in this excerpt: the slot is bound by
// reference and, on some path, overwritten with the current size of `pointers`.
// NOTE(review): presumably that path appends a fresh node at the end of
// `pointers` and `position` (taken by reference) is moved to the child's
// offset - confirm against the full source.
93 void next(
int shift,
int& position) {
94 auto& current = pointers[position + shift];
96 current = pointers.size();
// Handle the final character of a barcode being inserted: look at the slot
// pointers[position + shift] and record in `status` how a duplicate was treated.
// The condition that detects an occupied slot and the statements between the
// case labels are not visible in this excerpt.
// Visible effects per duplicate policy:
//   FIRST - no visible effect in this excerpt;
//   LAST  - sets status.duplicate_replaced;
//   NONE  - sets status.duplicate_cleared;
//   ERROR - throws std::runtime_error naming the two clashing entries.
104 void end(
int shift,
int position, AddStatus& status) {
105 auto& current = pointers[position + shift];
109 case DuplicateAction::FIRST:
111 case DuplicateAction::LAST:
112 status.duplicate_replaced =
true;
115 case DuplicateAction::NONE:
116 status.duplicate_cleared =
true;
119 case DuplicateAction::ERROR:
// The message reports `current + 1` and `counter + 1`; `counter` is declared
// outside this view. NOTE(review): the +1 presumably converts 0-based indices
// to 1-based ones for the user - confirm.
120 throw std::runtime_error(
"duplicate sequences detected (" +
121 std::to_string(current + 1) +
", " +
122 std::to_string(counter + 1) +
") when constructing the trie");
// NOTE(review): presumably reached for every non-throwing duplicate policy - confirm.
128 status.is_duplicate =
true;
// Insert `barcode_seq` into the trie starting at character `i`, beneath the
// node at `position`, expanding ambiguous base codes into every base they
// stand for. Outcomes are reported through `status`.
// Throws std::runtime_error for a character that is neither recognised by
// base_shift<true>() nor handled by the switch below.
132 void recursive_add(
size_t i,
int position,
const char* barcode_seq, AddStatus& status) {
// Recognised bases are handled directly: end() when the incremented index
// reaches `length`, next() otherwise. What happens when the shift is
// NON_STANDARD_BASE is not visible in this excerpt.
136 auto shift = base_shift<true>(barcode_seq[i]);
137 if (shift == NON_STANDARD_BASE) {
141 if ((++i) == length) {
142 end(shift, position, status);
145 next(shift, position);
150 status.has_ambiguous =
true;
// Helper applied once per concrete base that the ambiguous code stands for.
// It finishes the barcode via end() if this is the last character, otherwise
// recurses on the next character using a local copy of the position
// (`curpos`), so `position` itself is not passed on to the recursion.
152 auto process = [&](
char base) ->
void {
153 auto shift = base_shift(base);
154 if (i + 1 == length) {
155 end(shift, position, status);
157 auto curpos = position;
159 recursive_add(i + 1, curpos, barcode_seq, status);
// Each group of process() calls below belongs to one case label; the labels
// are not visible in this excerpt.
// NOTE(review): the base sets match the IUPAC codes R, Y, S, W, K, M, B, D,
// H, V, N in that order - confirm against the full source.
163 switch(barcode_seq[i]) {
165 process(
'A'); process(
'G');
break;
167 process(
'C'); process(
'T');
break;
169 process(
'C'); process(
'G');
break;
171 process(
'A'); process(
'T');
break;
173 process(
'G'); process(
'T');
break;
175 process(
'A'); process(
'C');
break;
177 process(
'C'); process(
'G'); process(
'T');
break;
179 process(
'A'); process(
'G'); process(
'T');
break;
181 process(
'A'); process(
'C'); process(
'T');
break;
183 process(
'A'); process(
'C'); process(
'G');
break;
185 process(
'A'); process(
'C'); process(
'G'); process(
'T');
break;
// NOTE(review): presumably the default branch of the switch - confirm.
187 throw std::runtime_error(
"unknown base '" + std::string(1, barcode_seq[i]) +
"' detected when constructing the trie");
202 recursive_add(0, 0, barcode_seq, status);
// Flat storage for the trie, indexed as pointers[node + shift] throughout
// this file. Entries are either passed on as the next node offset or, at the
// last position, returned or compared as candidate indices (checked with >= 0).
// NOTE(review): negative entries appear to be status codes such as
// STATUS_MISSING, which is defined outside this view - confirm.
227 std::vector<int> pointers;
// Value assigned by base_shift<true>() for a character it does not recognise.
229 static constexpr int NON_STANDARD_BASE = -1;
// Map a base character to its integer shift within a trie node.
// The mapping for recognised bases is not visible in this excerpt.
// For an unrecognised character: when `allow_unknown` is true the shift is
// set to NON_STANDARD_BASE; the std::runtime_error below is thrown otherwise.
// NOTE(review): the `else` separating the two outcomes is not visible -
// confirm that the throw is only reachable when `allow_unknown` is false.
231 template<
bool allow_unknown = false>
232 static int base_shift(
char base) {
247 if constexpr(allow_unknown) {
248 shift = NON_STANDARD_BASE;
250 throw std::runtime_error(
"unknown base '" + std::string(1, base) +
"' detected when constructing the trie");
// Duplicate-handling policy. The search helpers compare it against
// DuplicateAction::FIRST and DuplicateAction::LAST to break ties between
// equally scoring matches that have different indices.
263 DuplicateAction duplicates;
// Merge a newly found search result (`chosen`) into the running best result.
// `best_index` is a reference and callers in this file pass a member of
// `best` for it (best.first / best.index), so writing it edits `best` in place.
// `best_score` is passed by value.
// Visible rules when `chosen_index` is non-negative:
//   - a strictly lower `chosen_score` is handled first (the body of that
//     branch is not visible in this excerpt);
//   - on equal scores with different indices, FIRST keeps the smaller index
//     and LAST keeps the larger one.
// NOTE(review): what happens for tied scores under any other policy is not
// visible here.
266 template<
class SearchResult_>
267 void replace_best_with_chosen(SearchResult_& best,
int& best_index,
int best_score,
const SearchResult_& chosen,
int chosen_index,
int chosen_score)
const {
268 if (chosen_index >= 0) {
269 if (chosen_score < best_score) {
271 }
else if (chosen_score == best_score) {
272 if (chosen_index != best_index) {
273 if (duplicates == DuplicateAction::FIRST) {
274 if (chosen_index < best_index) {
275 best_index = chosen_index;
277 }
else if (duplicates == DuplicateAction::LAST) {
278 if (chosen_index > best_index) {
279 best_index = chosen_index;
// NOTE(review): this second score comparison presumably belongs to the branch
// taken when `chosen_index` is negative; its bodies are not visible - confirm.
288 if (chosen_score < best_score) {
290 }
else if (chosen_score == best_score) {
// At the last barcode position, examine every child slot of `node` and fold
// any non-negative candidate into `current_index`, updating `mismatch_cap`
// (both are in/out references).
// How `refshift` is used is not visible in this excerpt.
// NOTE(review): presumably the slot equal to `refshift` is skipped because
// the caller has already tried it - confirm.
300 void scan_final_position_with_mismatch(
int node,
int refshift,
int& current_index,
int current_mismatches,
int& mismatch_cap)
const {
302 for (
int s = 0; s < NUM_BASES; ++s) {
307 int candidate = pointers[node + s];
308 if (candidate >= 0) {
// Candidates with a different index are resolved by the duplicate policy:
// FIRST keeps the smaller index, LAST keeps the larger one.
310 if (candidate != current_index) {
311 if (duplicates == DuplicateAction::FIRST) {
312 if (current_index > candidate) {
313 current_index = candidate;
315 }
else if (duplicates == DuplicateAction::LAST) {
316 if (current_index < candidate) {
317 current_index = candidate;
// The conditions guarding the assignments below are not visible here; on
// those paths the candidate is accepted and/or the cap is lowered to the
// current mismatch count.
325 current_index = candidate;
326 mismatch_cap = current_mismatches;
339 mismatch_cap = current_mismatches;
// If is_optimal() rejects the current layout (checked from node 0, position
// 0), build a replacement vector with the same capacity as `pointers` and
// swap it in. The statement that fills `replacement` is not visible here.
// NOTE(review): presumably a call to the recursive optimize(node, pos, trie)
// overload - confirm.
358 if (!is_optimal(0, 0, maxed)) {
359 std::vector<int> replacement;
360 replacement.reserve(pointers.size());
362 pointers.swap(replacement);
// Recursively check the subtree under `node`, visiting each of its NUM_BASES
// child slots and recursing into child `v`.
// `maxed` is an in/out accumulator shared across the recursion.
// The conditions that skip a child, update `maxed`, or fail the check are not
// visible in this excerpt.
// NOTE(review): presumably `maxed` tracks the largest node offset seen so far,
// so the check fails when nodes are not laid out in visiting order - confirm.
370 bool is_optimal(
int node,
size_t pos,
int& maxed)
const {
373 for (
int s = 0; s < NUM_BASES; ++s) {
374 auto v = pointers[node + s];
384 if (!is_optimal(v, pos, maxed)) {
// Copy the node at offset `node` of `pointers` to the end of `trie` and then
// walk the copied child slots by reference.
// What is done with each slot `v` is not visible in this excerpt.
// NOTE(review): `v` refers into `trie`; if the hidden loop body grows `trie`
// (e.g. by recursing), `v` stays valid only while no reallocation happens.
// The visible caller reserves pointers.size() beforehand - confirm that this
// is sufficient.
392 void optimize(
int node,
size_t pos, std::vector<int>& trie)
const {
393 auto it = pointers.begin() + node;
// Offset at which this node's copy begins inside `trie`.
394 size_t new_node = trie.size();
395 trie.insert(trie.end(), it, it + NUM_BASES);
399 for (
int s = 0; s < NUM_BASES; ++s) {
400 auto& v = trie[new_node + s];
// Public entry point: search for `search_seq` allowing up to `max_mismatches`.
// Delegates to the recursive overload starting at position 0, node 0, with 0
// mismatches so far. `max_mismatches` is a by-value copy here, so the
// recursive overload (which takes it by reference and may lower it) never
// touches the caller's variable.
// Returns the pair produced by the recursive overload, which builds it as
// (index or status code, mismatch count).
446 std::pair<int, int>
search(
const char* search_seq,
int max_mismatches)
const {
447 return search(search_seq, 0, 0, 0, max_mismatches);
// Recursive search step for character `pos` of `seq` at trie node `node`,
// with `mismatches` already spent. `max_mismatches` is an in/out cap that is
// lowered once a match is found.
// Returns (index or status code, mismatch count).
451 std::pair<int, int>
search(
const char* seq,
size_t pos,
int node,
int mismatches,
int& max_mismatches)
const {
// An unrecognised base gives a negative shift and is treated as STATUS_MISSING.
452 int shift = base_shift<true>(seq[pos]);
453 int current = (shift >= 0 ? pointers[node + shift] :
STATUS_MISSING);
// Last position. One visible path lowers the cap to `mismatches` and returns
// `current`; its guarding condition is not visible in this excerpt.
458 if (pos + 1 == length) {
460 max_mismatches = mismatches;
461 return std::make_pair(current, mismatches);
// The other visible path scans the sibling slots if the cap allows and
// returns `alt`. The declaration of `alt` and any change to `mismatches`
// before the check are not visible here.
466 if (mismatches <= max_mismatches) {
467 scan_final_position_with_mismatch(node, shift, alt, mismatches, max_mismatches);
470 return std::make_pair(alt, mismatches);
// Not the last position: follow `current` first. The guarding condition, the
// declaration of `best`, and any update to `pos` are not visible here.
477 best =
search(seq, pos, current, mismatches, max_mismatches);
// Then, if the cap allows, try the child slots and merge each result into
// `best`. The filters on `alt` are not visible here.
481 if (mismatches <= max_mismatches) {
482 for (
int s = 0; s < NUM_BASES; ++s) {
487 int alt = pointers[node + s];
// The cap is re-checked inside the loop because an earlier recursive call
// takes `max_mismatches` by reference and may have lowered it.
492 if (mismatches <= max_mismatches) {
493 auto chosen =
search(seq, pos, alt, mismatches, max_mismatches);
494 replace_best_with_chosen(best, best.first, best.second, chosen, chosen.first, chosen.second);
// Constructor fragment: MismatchTrie is initialised with the sum of all
// entries in `segments` and the duplicate policy.
// NOTE(review): MismatchTrie is presumably the base class and the sum the
// total barcode length - confirm against the full source.
528 MismatchTrie(std::accumulate(segments.begin(), segments.end(), 0), duplicates),
// Turn `boundaries` into running totals: afterwards boundaries[i] is the sum
// of its original entries 0..i. How `boundaries` is first filled is not
// visible in this excerpt.
531 for (
size_t i = 1; i < num_segments; ++i) {
532 boundaries[i] += boundaries[i-1];
// Public entry point for the segmented search: `max_mismatches` holds one
// mismatch limit per segment.
// The overall cap starts as the sum of the per-segment limits and is held in
// a local variable, which is passed to the recursive overload's
// `total_mismatches` reference parameter.
// The recursion starts at position 0, segment 0, with a default-constructed
// Result as the running state.
577 Result search(
const char* search_seq,
const std::array<int, num_segments>& max_mismatches)
const {
578 int total_mismatches = std::accumulate(max_mismatches.begin(), max_mismatches.end(), 0);
579 return search(search_seq, 0, 0,
Result(), max_mismatches, total_mismatches);
// Recursive segmented search. The start of the signature is not visible in
// this excerpt; the visible call sites pass
// (seq, pos, segment_id, state, segment_mismatches, total_mismatches).
// `segment_mismatches` holds the per-segment limits. `total_mismatches` is an
// in/out overall cap that is lowered once a match is found.
588 const std::array<int, num_segments>& segment_mismatches,
589 int& total_mismatches
// `state` carries the current node in `index`, the running mismatch count in
// `total`, and per-segment counts in `per_segment`.
593 int node = state.
index;
// An unrecognised base gives a negative shift and is treated as STATUS_MISSING.
595 int shift = base_shift<true>(seq[pos]);
596 int current = (shift >= 0 ? pointers[node + shift] :
STATUS_MISSING);
// Last position. One visible path lowers the cap to the running total and
// stores `current` in the state; its guarding condition is not visible here.
601 if (pos + 1 == length) {
603 total_mismatches = state.total;
604 state.index = current;
// Otherwise a mismatch is counted against the current segment, and sibling
// slots are scanned only if both the overall and the per-segment limits still
// hold. Any update to state.total before the check is not visible here.
610 auto& current_segment_mm = state.per_segment[segment_id];
611 ++current_segment_mm;
613 if (state.total <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
614 scan_final_position_with_mismatch(node, shift, state.index, state.total, total_mismatches);
// Not the last position: advance, detecting a segment change when the next
// position reaches the cumulative boundary of the current segment. The
// update applied to `next_segment_id` is not visible here.
620 auto next_pos = pos + 1;
621 auto next_segment_id = segment_id;
622 if (
static_cast<int>(next_pos) == boundaries[segment_id]) {
// `best` starts one above the cap so that any acceptable result beats it.
// Its declaration is not visible here.
628 best.total = total_mismatches + 1;
// Follow `current` first. The guarding condition is not visible here.
631 state.index = current;
632 best =
search(seq, next_pos, next_segment_id, state, segment_mismatches, total_mismatches);
// Then spend a mismatch in this segment and try the child slots, merging each
// result into `best`. The filters on `alt` and the updates to `state` before
// the recursive call are not visible here.
636 auto& current_segment_mm = state.per_segment[segment_id];
637 ++current_segment_mm;
639 if (state.total <= total_mismatches && current_segment_mm <= segment_mismatches[segment_id]) {
640 for (
int s = 0; s < NUM_BASES; ++s) {
645 int alt = pointers[node + s];
// The cap is re-checked inside the loop because an earlier recursive call
// takes `total_mismatches` by reference and may have lowered it.
650 if (state.total <= total_mismatches) {
652 auto chosen =
search(seq, next_pos, next_segment_id, state, segment_mismatches, total_mismatches);
653 replace_best_with_chosen(best, best.index, best.total, chosen, chosen.index, chosen.total);
// One entry per segment. The constructor loop accumulates them into running
// totals, and the recursive search compares the next position against
// boundaries[segment_id] to detect the end of the current segment.
662 std::array<int, num_segments> boundaries;