#include "KD_Tree.h"

#include <algorithm>
#include <cassert>
#include <vector>

using std::vector;

namespace {

// A child subtree is only built when it discards at least this many of its
// parent's candidate elements; otherwise the parent stays a leaf on that side.
const size_t MIN_ELEMS_PRUNED = 4;

// Returns the ids (the `first` member) of the given elements, in stored order.
vector<int> elem_ids(const vector<Elem> &elems) {
  vector<int> ids;
  ids.reserve(elems.size());
  for (vector<Elem>::const_iterator it = elems.begin(); it != elems.end(); ++it) {
    ids.push_back(it->first);
  }
  return ids;
}

// Appends `node` to `pts` unless a node comparing equal (Node's operator==)
// is already present. Linear scan, as in the original implementation.
void push_unique(vector<Node> &pts, const Node &node) {
  if (std::find(pts.begin(), pts.end(), node) == pts.end()) {
    pts.push_back(node);
  }
}

} // namespace

/**
 * Builds a KD tree over a mesh.
 *
 * @param nNodesPerElem number of nodes per element (first index of conn)
 * @param nNodes        number of nodes; node i takes its coordinates from
 *                      (*nodalCoords)(0..2, i)
 * @param nodalCoords   nodal coordinates, rows 0, 1, 2 read as x, y, z
 * @param nElems        number of elements (second index of conn)
 * @param conn          connectivity: conn(node, elem) is an index into the
 *                      node list and must lie in [0, nNodes)
 * @return a heap-allocated tree; the caller owns it. The tree takes ownership
 *         of the node and element lists built here.
 */
KD_Tree *KD_Tree::create_KD_tree(const int nNodesPerElem, const int nNodes,
                                 const DENS_MAT *nodalCoords, const int nElems,
                                 const Array2D<int> &conn) {
  vector<Node> *points = new vector<Node>();
  for (int node = 0; node < nNodes; node++) {
    points->push_back(Node(node, (*nodalCoords)(0, node),
                                 (*nodalCoords)(1, node),
                                 (*nodalCoords)(2, node)));
  }

  vector<Elem> *elements = new vector<Elem>();
  for (int elem = 0; elem < nElems; elem++) {
    // Each element carries copies of its nodes, looked up via connectivity.
    vector<Node> nodes;
    for (int node = 0; node < nNodesPerElem; node++) {
      const int idx = conn(node, elem);
      assert(idx >= 0 && idx < nNodes);
      nodes.push_back((*points)[idx]);
    }
    elements->push_back(Elem(elem, nodes));
  }
  return new KD_Tree(points, elements);
}

// Frees the point and element lists owned by this node and, recursively,
// both subtrees (deleting NULL is a no-op).
// NOTE(review): the members are owning raw pointers, so copying a KD_Tree
// would double-delete; confirm copying is disabled in KD_Tree.h.
KD_Tree::~KD_Tree() {
  delete sortedPts_;
  delete candElems_;
  delete leftChild_;
  delete rightChild_;
}

/**
 * Builds one tree node, splitting along `dimension` (0 = x, 1 = y, anything
 * else = z) and recursing along (dimension + 1) % 3.
 *
 * Takes ownership of `points` (sorted in place along `dimension`) and
 * `elements`; both are released by the destructor.
 */
KD_Tree::KD_Tree(vector<Node> *points, vector<Elem> *elements, int dimension)
  : candElems_(elements) {
  sortedPts_ = points;
  leftChild_ = NULL;
  rightChild_ = NULL;

  // Without this guard the median lookup below would index an empty vector.
  // Children are only built from non-empty point lists, so this is reachable
  // only for a root created with no nodes; such a node stays a leaf.
  if (points->empty()) {
    return;
  }

  bool (*compare)(Node, Node);
  if (dimension == 0) {
    compare = Node::compareX;
  } else if (dimension == 1) {
    compare = Node::compareY;
  } else {
    compare = Node::compareZ;
  }

  // Sort along this dimension and use the median point as the split value.
  std::sort(points->begin(), points->end(), compare);
  value_ = (*sortedPts_)[points->size() / 2];

  // Partition the candidate elements. An element goes left if any of its
  // nodes satisfies compare(node, split), and right if any satisfies
  // compare(split, node). Elements straddling the split go to both sides.
  // Nodes equivalent to the split value (neither test holds) go to neither.
  vector<Node> *leftPts = new vector<Node>;
  vector<Elem> *leftElems = new vector<Elem>;
  vector<Node> *rightPts = new vector<Node>;
  vector<Elem> *rightElems = new vector<Elem>;
  for (vector<Elem>::iterator elit = candElems_->begin();
       elit != candElems_->end(); ++elit) {
    bool foundElemLeft = false;
    bool foundElemRight = false;
    for (vector<Node>::iterator ndit = elit->second.begin();
         ndit != elit->second.end(); ++ndit) {
      if (compare(*ndit, value_)) {
        push_unique(*leftPts, *ndit);
        foundElemLeft = true;
      }
      if (compare(value_, *ndit)) {
        push_unique(*rightPts, *ndit);
        foundElemRight = true;
      }
    }
    if (foundElemLeft) leftElems->push_back(*elit);
    if (foundElemRight) rightElems->push_back(*elit);
  }

  // Build a child only if it is non-empty and prunes enough candidates.
  // Otherwise free the unused lists and leave that side NULL. The subtraction
  // cannot wrap because each side is a subset of candElems_.
  const size_t nCand = candElems_->size();
  if (nCand - leftElems->size() < MIN_ELEMS_PRUNED || leftElems->empty()) {
    delete leftPts;
    delete leftElems;
  } else {
    leftChild_ = new KD_Tree(leftPts, leftElems, (dimension + 1) % 3);
  }

  if (nCand - rightElems->size() < MIN_ELEMS_PRUNED || rightElems->empty()) {
    delete rightPts;
    delete rightElems;
  } else {
    rightChild_ = new KD_Tree(rightPts, rightElems, (dimension + 1) % 3);
  }
}

/**
 * Returns the ids of the candidate elements stored at the tree node reached
 * by descending with `query`.
 *
 * At each level the query goes left when
 * query.lessThanInDimension(value_, dimension) holds, and right otherwise,
 * cycling the dimension as during construction. The descent stops at the
 * first missing child, and that node's candidates are returned in stored
 * order.
 */
vector<int> KD_Tree::find_nearest_elements(Node query, int dimension) {
  // At a leaf both branches give the same answer, so skip the comparison.
  // This also avoids comparing against a split value that was never assigned
  // when the tree was built from an empty point list.
  if (leftChild_ == NULL && rightChild_ == NULL) {
    return elem_ids(*candElems_);
  }

  KD_Tree *child = query.lessThanInDimension(value_, dimension) ? leftChild_
                                                                : rightChild_;
  if (child == NULL) {
    return elem_ids(*candElems_);
  }
  return child->find_nearest_elements(query, (dimension + 1) % 3);
}

/**
 * Returns the candidate element ids of the subtrees `depth` levels below
 * this node, left to right, each list sorted ascending.
 *
 * The result always has 2^depth entries. When a node above that depth is
 * missing either child, its own candidates fill the first slot of its share
 * and the remaining slots are empty vectors.
 *
 * @param depth number of levels to descend; must be >= 0
 */
vector<vector<int> > KD_Tree::getElemIDs(int depth) {
  assert(depth >= 0);
  vector<vector<int> > result;

  if (depth == 0) {
    vector<int> ids = elem_ids(*candElems_);
    std::sort(ids.begin(), ids.end());
    result.push_back(ids);
  } else if (leftChild_ == NULL || rightChild_ == NULL) {
    result = getElemIDs(0);
    // Pad with empty vectors up to 2^depth slots (integer shift rather than
    // floor(pow(2, depth)) converted to int).
    const size_t nSlots = static_cast<size_t>(1) << depth;
    result.resize(nSlots);
  } else {
    result = leftChild_->getElemIDs(depth - 1);
    const vector<vector<int> > right = rightChild_->getElemIDs(depth - 1);
    result.insert(result.end(), right.begin(), right.end());
  }
  return result;
}