Commit 5b1631c0 authored by Overduin, Sam's avatar Overduin, Sam Committed by Overduin, Sam
Browse files

Improved taxonomy assignment & implemented taxonomy transition filter.

barcode_index.hpp:
- Initialized TaxIdEncoder with "0" taxatree
- Added func TaxIdEncoder::ToTaxaTreeVector
- Bugfix EdgeEntry::GetTaxonomy()

barcode_info_extractor.hpp:
- Added funcs: ToTaxaTreeVector, TaxaTreeFromTaxId, GetTaxaTreeFromEdge

barcode_index_construction.cpp:
- Moved ToTaxaTreeVector to extractor & barcode_index
- Improved majority_vote_lca to actually use LCA

construction_callers.cpp & hpp:
- Added TaxaBreakConstructorCaller to break mismatched taxa in
  contracted assembly graph

read_cloud_connection_conditions.cpp:
- Added TaxaBreakPredicate that returns false when transition
  (scaffold_edge) has incompatible taxonomy

scaffold_graph_construction_pipeline.cpp:
- Added TaxaBreak routine to basic mode & scaffolding mode
parent ba6cfe86
......@@ -84,7 +84,8 @@ namespace barcode_index {
TaxIdEncoder():
codes_(),
codes_rev_()
{ }
{ string empty_taxa_tree = "0"; //initialise with empty taxatree
AddTaxaTree(empty_taxa_tree);}
void AddTaxaTree(string &taxatree) {
auto it = codes_.find(taxatree);
......@@ -106,6 +107,20 @@ namespace barcode_index {
return std::stoi(taxid_str);
}
std::vector<TaxId> ToTaxaTreeVector(const std::string& taxa_tree_string, const char sep='.') const {
std::vector<TaxId> taxa_tree_vect;
std::string taxa;
TaxId taxid;
std::stringstream tree_stream(taxa_tree_string); // Insert the string into a stream
while(getline(tree_stream, taxa, sep)) {
// string to uint64_t
std::stringstream taxa_stream(taxa);
taxa_stream >> taxid;
taxa_tree_vect.push_back(taxid);
}
return taxa_tree_vect;
}
TaxId GetCode (const string& taxatree) const {
VERIFY(codes_.find(taxatree) != codes_.end());
return codes_.at(taxatree);
......@@ -590,7 +605,7 @@ namespace barcode_index {
taxonomy_ = taxid;
}
TaxId GetTaxonomy() {
TaxId GetTaxonomy() const {
return taxonomy_;
}
......
......@@ -153,6 +153,22 @@ namespace barcode_index {
return index_.edge_to_entry_.at(edge_id_).taxid_distribution_.at(taxid).GetCount();
}
std::vector<TaxId> ToTaxaTreeVector(const std::string& taxa_tree_string, const char sep='.') const {
return index_.taxatree_codes_.ToTaxaTreeVector(taxa_tree_string, sep);
}
string TaxaTreeFromTaxId(TaxId& taxid) const {
return index_.taxatree_codes_.GetTaxaTree(taxid);
}
string GetTaxaTreeFromEdge(const EdgeId& edge_id_) const {
//INFO("EdgeId: " << edge_id_);
TaxId taxid = index_.edge_to_entry_.at(edge_id_).GetTaxonomy();
//INFO("GetTaxonomy: "<< taxid);
string taxatree = TaxaTreeFromTaxId(taxid);
return taxatree;
}
typename taxid_distribution_t::const_iterator taxid_iterator_begin(const EdgeId &edge) const {
auto entry_it = index_.GetEntryHeadsIterator(edge);
return entry_it->second.taxid_begin(); // second is EdgeEntry
......
......@@ -146,12 +146,14 @@ CompositeConnectionConstructorCaller::CompositeConnectionConstructorCaller(
gp_(gp), main_extractor_(main_extractor), long_edge_extractor_(barcode_extractor),
unique_storage_(unique_storage), search_parameter_pack_(search_parameter_pack),
scaff_con_configs_(scaff_con_configs), max_threads_(max_threads), scaffolding_mode_(scaffolding_mode) {}
EdgeSplitConstructorCaller::EdgeSplitConstructorCaller(
const Graph &g_,
std::shared_ptr<barcode_index::SimpleScaffoldVertexIndexInfoExtractor> barcode_extractor_,
std::size_t max_threads_)
: IterativeScaffoldGraphConstructorCaller("Conjugate filter"),
g_(g_), barcode_extractor_(barcode_extractor_), max_threads_(max_threads_) {}
std::shared_ptr<path_extend::scaffolder::ScaffoldGraphConstructor> EdgeSplitConstructorCaller::GetScaffoldGraphConstuctor(
const ScaffolderParams &params,
const ScaffoldGraph &scaffold_graph) const {
......@@ -164,6 +166,28 @@ std::shared_ptr<path_extend::scaffolder::ScaffoldGraphConstructor> EdgeSplitCons
max_threads_);
return constructor;
}
//TODO: Add taxonomy getter to barcode_extractor_
TaxaBreakConstructorCaller::TaxaBreakConstructorCaller(
const Graph &g_,
std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_,
std::size_t max_threads_)
: IterativeScaffoldGraphConstructorCaller("Taxonomy based filter"),
g_(g_), barcode_extractor_(barcode_extractor_), max_threads_(max_threads_) {}
std::shared_ptr<path_extend::scaffolder::ScaffoldGraphConstructor> TaxaBreakConstructorCaller::GetScaffoldGraphConstuctor(
const read_cloud::ScaffolderParams &params,
const ScaffoldGraph &scaffold_graph) const {
auto predicate = std::make_shared<TaxaBreakPredicate>(g_, barcode_extractor_);
auto constructor =
std::make_shared<path_extend::scaffolder::PredicateScaffoldGraphFilter>(g_,
scaffold_graph,
predicate,
max_threads_);
return constructor;
}
TransitiveConstructorCaller::TransitiveConstructorCaller(const Graph &g_,
std::size_t max_threads_)
: IterativeScaffoldGraphConstructorCaller("Transitive filter"),
......
......@@ -159,6 +159,29 @@ class EdgeSplitConstructorCaller : public IterativeScaffoldGraphConstructorCalle
std::shared_ptr<barcode_index::SimpleScaffoldVertexIndexInfoExtractor> barcode_extractor_;
std::size_t max_threads_;
};
/** ConstructorCaller that filters wrong taxonomic transitions.
* Eg 1.2.3 -/-> 1.2.4 (mismatch in taxonomy so wrong),
* 1.2.3 --> 0 (taxonomy undefined so not wrong),
* 1.2.3 --> 1.2 (within same hierarchy so not wrong).
*/
class TaxaBreakConstructorCaller : public IterativeScaffoldGraphConstructorCaller {
public:
using IterativeScaffoldGraphConstructorCaller::ScaffoldGraph;
TaxaBreakConstructorCaller(const Graph &g_,
std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_,
std::size_t max_threads_);
std::shared_ptr<scaffolder::ScaffoldGraphConstructor> GetScaffoldGraphConstuctor(
const read_cloud::ScaffolderParams &params,
const ScaffoldGraph &scaffold_graph) const override;
private:
const Graph &g_;
std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_;
std::size_t max_threads_;
};
/** ConstructorCaller that filters transitive connections.
*/
class TransitiveConstructorCaller : public IterativeScaffoldGraphConstructorCaller {
......
......@@ -110,6 +110,44 @@ bool ReadCloudMiddleDijkstraPredicate::Check(const scaffold_graph::ScaffoldGraph
return false;
}
TaxaBreakPredicate::TaxaBreakPredicate(
const Graph &g_,
std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor)
: g_(g_),
barcode_extractor_(barcode_extractor) {}
bool TaxaBreakPredicate::Check(const ScaffoldEdge &scaffold_edge) const {
//barcode_extractor_. ;
static const std::string null_taxa = "0";
auto first = scaffold_edge.getStart();
auto second = scaffold_edge.getEnd();
DEBUG("In TaxaBreakPredicate, Edge_id: " << scaffold_edge.getId() << " length: " << scaffold_edge.getLength());
auto seq1 = first.GetSequence(g_);
auto seq2 = second.GetSequence(g_);
DEBUG("Vertex_1_id: " << first.int_id() << " Taxonomy: " << barcode_extractor_->GetTaxaTreeFromEdge(first.int_id()) << " Length: " << seq1.size());
DEBUG("Vertex_2_id: " << second.int_id() << " Taxonomy: " << barcode_extractor_->GetTaxaTreeFromEdge(second.int_id()) << " Length: " << seq2.size());
std::string taxatree_1 = barcode_extractor_->GetTaxaTreeFromEdge(first.int_id());
std::string taxatree_2 = barcode_extractor_->GetTaxaTreeFromEdge(second.int_id());
if (taxatree_1 == null_taxa || taxatree_2 == null_taxa ||
taxatree_1.find(taxatree_2) != std::string::npos || taxatree_2.find(taxatree_1) != std::string::npos) {
DEBUG("Taxatree match!");
return true;
}
else {
DEBUG("Taxatree mismatch!");
}
return false;
//TODO: implement taxonomy in barcode_extractor and vertexes!!
//barcode_extractor_->
//get taxonomy for first and second
//first.
//bool result = true;
//return result;
}
EdgeSplitPredicate::EdgeSplitPredicate(
const Graph &g_,
std::shared_ptr<barcode_index::SimpleScaffoldVertexIndexInfoExtractor> barcode_extractor,
......
......@@ -107,6 +107,22 @@ class CompositeConnectionPredicate : public ScaffoldEdgePredicate {
DECL_LOGGER("CompositeConnectionPredicate");
};
class TaxaBreakPredicate : public ScaffoldEdgePredicate {
public:
using ScaffoldEdgePredicate::ScaffoldEdge;
typedef barcode_index::BarcodeId BarcodeId;
TaxaBreakPredicate(const Graph &g_,
std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor);
bool Check(const ScaffoldEdge &scaffold_edge) const override;
private:
const Graph &g_;
std::shared_ptr<barcode_index::FrameBarcodeIndexInfoExtractor> barcode_extractor_;
DECL_LOGGER("TaxaBreakPredicate");
};
class EdgeSplitPredicate : public ScaffoldEdgePredicate {
public:
using ScaffoldEdgePredicate::ScaffoldEdge;
......
......@@ -99,6 +99,7 @@ std::vector<ScaffoldGraphConstructionPipeline::ResultT> ScaffoldGraphConstructio
ScaffoldGraphPipelineConstructor::ScaffoldGraphPipelineConstructor(const ReadCloudConfigsT &configs, const Graph &g) :
read_cloud_configs_(configs), g_(g) {}
std::shared_ptr<ScaffoldGraphPipelineConstructor::ScaffoldVertexExtractor> ScaffoldGraphPipelineConstructor::ConstructSimpleEdgeIndex(
const std::set<ScaffoldGraphPipelineConstructor::ScaffoldVertex> &scaffold_vertices,
ScaffoldGraphPipelineConstructor::BarcodeIndexPtr barcode_extractor,
......@@ -241,6 +242,8 @@ std::vector<std::shared_ptr<IterativeScaffoldGraphConstructorCaller>> FullScaffo
unique_storage_, search_parameter_pack_,
read_cloud_configs_.scaff_con, max_threads_,
scaffolding_mode));
iterative_constructor_callers.push_back(
std::make_shared<TaxaBreakConstructorCaller>(gp_.g, barcode_extractor_, max_threads_));
const size_t min_pipeline_length = read_cloud_configs_.long_edge_length_lower_bound;
bool launch_full_pipeline = min_length_ > min_pipeline_length;
......@@ -316,6 +319,8 @@ std::vector<std::shared_ptr<IterativeScaffoldGraphConstructorCaller>> MergingSca
split_scaffold_index_extractor,
max_threads_));
iterative_constructor_callers.push_back(std::make_shared<TransitiveConstructorCaller>(g_, max_threads_));
iterative_constructor_callers.push_back(
std::make_shared<TaxaBreakConstructorCaller>(g_, barcode_extractor_, max_threads_));
return iterative_constructor_callers;
}
} //path_extend
......
......@@ -18,21 +18,10 @@ namespace debruijn_graph {
return has_read_clouds;
}
std::vector<TaxId> ToTaxaTreeVector(const std::string& taxa_tree_string, const char sep='.') {
std::vector<TaxId> taxa_tree_vect;
std::string taxa;
TaxId taxid;
std::stringstream tree_stream(taxa_tree_string); // Insert the string into a stream
while(getline(tree_stream, taxa, sep)) {
// string to uint64_t
std::stringstream taxa_stream(taxa);
taxa_stream >> taxid;
taxa_tree_vect.push_back(taxid);
}
return taxa_tree_vect;
}
TaxId majority_vote_lca(std::vector<string> &taxatree_str_vec, std::vector<size_t> &count_vec) {
TaxId majority_vote_lca(std::vector<string>& taxatree_str_vec, std::vector<size_t>& count_vec,
const FrameBarcodeIndexInfoExtractor& extractor) {
std::vector<std::vector<TaxId>> taxatree_vec;
size_t longest_lineage = 0;
size_t total_counts = 0;
......@@ -40,7 +29,7 @@ namespace debruijn_graph {
// transfer taxatree_strings to taxid_vectors.
for ( const std::string& taxatree_str : taxatree_str_vec ) {
if (taxatree_str != "0") {
std::vector<TaxId> taxid_vec = ToTaxaTreeVector(taxatree_str, '.');
std::vector<TaxId> taxid_vec = extractor.ToTaxaTreeVector(taxatree_str, '.');
taxatree_vec.push_back(taxid_vec);
longest_lineage = std::max(taxid_vec.size(), longest_lineage);
}
......@@ -57,33 +46,62 @@ namespace debruijn_graph {
total_counts += count;
}
size_t min_majority = total_counts/2.0;
//min_majority = 3; //temporary for testing purposes ToDo: remove this one.
// find index of taxid with highest count
auto most_common_index = std::distance(count_vec.begin(),
std::max_element(count_vec.begin(), count_vec.end()));
if (count_vec[most_common_index] > min_majority and total_counts > (0.2*(total_counts + null_counts))){
return taxatree_vec[most_common_index].back();
TaxId lca = 0;
VERIFY_MSG(taxatree_vec.size() == count_vec.size(), "ERROR: taxatree_vec.size is not count_vec.size in max_lca");
if (count_vec.size() > 0) { //in this case only taxa 0 was part of taxatree_str_vec.
if (total_counts > (0.20 * (total_counts + null_counts))) { //minimum 20% assigned taxids.
size_t i = 0;
TaxId proposed_lca = 0;
size_t prop_lca_counts = total_counts;
while (i < longest_lineage && prop_lca_counts >= min_majority) {
prop_lca_counts = 0;
if (i < taxatree_vec[most_common_index].size()){
proposed_lca = taxatree_vec[most_common_index][i];
} else {
size_t max_sub_count = 0;
for ( auto taxatree : taxatree_vec ) {
if ( i < taxatree.size() && count_vec[i] > max_sub_count ) {
//get most common taxa as proposed_lca within taxatree that aren't too long.
proposed_lca = taxatree[i];
max_sub_count = count_vec[i];
}
}
}
size_t count_vec_pos = 0;
for ( auto taxatree : taxatree_vec ) {
if ( i < taxatree.size() && taxatree[i] == proposed_lca) {
prop_lca_counts += count_vec[count_vec_pos];
}
count_vec_pos += 1;
}
if ( prop_lca_counts >= min_majority ) {
lca = proposed_lca;
}
++i;
}
}
}
return 0;
return lca;
}
TaxId last_common_ancestor(std::vector<string> &taxatree_vec, std::vector<size_t> &count_vec) {
VERIFY_MSG(taxatree_vec.size() == count_vec.size(), "ERROR: taxatree_vec.size is not count_vec.size during lca");
TaxId last_common_ancestor(std::vector<string> &taxatree_vec, std::vector<size_t> &count_vec,
const FrameBarcodeIndexInfoExtractor& extractor) {
VERIFY_MSG(taxatree_vec.size() == count_vec.size(), "ERROR: taxatree_vec.size is not count_vec.size before lca");
// wrapper to change underlying lca_algorithm.
TaxId lca = 0;
lca = majority_vote_lca(taxatree_vec, count_vec);
//INFO("Starting LCA");
lca = majority_vote_lca(taxatree_vec, count_vec, extractor);
//INFO("Done LCA)");
return lca;
}
void assign_taxonomy_to_edges(barcode_index::FrameBarcodeIndex<debruijn_graph::DeBruijnGraph>& barcodeindex,
const FrameBarcodeIndexInfoExtractor& extractor){
//for edge in edge_iterator
// get taxid_lst + count_lst
// taxid_lst to taxatree_lst
// LCA = lca_method(taxatree_lst, count_lst)
// edge.SetTaxonomy(LCA)
for ( auto &p : barcodeindex.edge_to_entry_ ) {
std::vector<string> taxatree_vector;
std::vector<size_t> count_vector;
......@@ -101,13 +119,14 @@ namespace debruijn_graph {
count = extractor.GetTaxidCount(edge, taxid);
count_vector.push_back(count);
}
TaxId lca = last_common_ancestor(taxatree_vector, count_vector);
p.second.SetTaxonomy(lca);
INFO("EdgeId: " << edge);
DEBUG("EdgeId: " << edge << " Frame_size: " << p.second.GetFrameSize() << " # of frames: " << p.second.GetNumberOfFrames());
for (size_t i = 0; i != taxatree_vector.size(); i++ ){
INFO("TaxaTree: " << taxatree_vector[i] << ", Count: " << count_vector[i]);
DEBUG("TaxaTree: " << taxatree_vector[i] << ", Count: " << count_vector[i]);
}
INFO("Taxonomy: " << barcodeindex.edge_to_entry_.at(edge).GetTaxonomy())
// Beware that lca function can mess with count_vector length so just use once.
TaxId lca = last_common_ancestor(taxatree_vector, count_vector, extractor);
p.second.SetTaxonomy(lca);
DEBUG("Taxonomy: " << barcodeindex.edge_to_entry_.at(edge).GetTaxonomy())
}
}
......@@ -137,7 +156,7 @@ namespace debruijn_graph {
mapper_builder.FillMap(reads, graph_pack.index, graph_pack.kmer_mapper);
INFO("Barcode index construction finished.");
FrameBarcodeIndexInfoExtractor extractor(graph_pack.barcode_mapper, graph_pack.g);
assign_taxonomy_to_edges(graph_pack.barcode_mapper, extractor); // function remade as method of barcode_mapper
assign_taxonomy_to_edges(graph_pack.barcode_mapper, extractor);
INFO("Taxonomy assigned to all edges");
size_t length_threshold = cfg::get().pe_params.read_cloud.long_edge_length_lower_bound;
INFO("Average barcode coverage: " + std::to_string(extractor.AverageBarcodeCoverage(length_threshold)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment