From 5d80eef7933cf65c9d68b9d3e9259bc9598c2bd5 Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Wed, 31 Aug 2016 09:46:08 -0300 Subject: Updated KNN. --- measures.cpp | 31 +++++++++++++++++-------------- mp.h | 2 +- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/measures.cpp b/measures.cpp index 2a3908f..7c30cd9 100644 --- a/measures.cpp +++ b/measures.cpp @@ -8,31 +8,34 @@ static const float EPSILON = 1e-6f; -arma::vec mp::neighborhoodPreservation(const arma::mat &distA, - const arma::mat &distB, - arma::uword k) +void mp::neighborhoodPreservation(const arma::mat &distA, + const arma::mat &distB, + arma::uword k, + arma::vec &v) { - int n = uintToInt(distA.n_rows); - arma::vec np(n); + int n = uintToInt(v.n_elem); - #pragma omp parallel for shared(np, n) + #pragma omp parallel for shared(v, n) for (int i = 0; i < n; i++) { - arma::uvec nnA(k); - arma::uvec nnB(k); - arma::vec dist(k); + //arma::uvec nnA(k); + //arma::uvec nnB(k); + //arma::vec dist(k); + + //mp::knn(distA, i, k, nnA, dist); + //mp::knn(distB, i, k, nnB, dist); - mp::knn(distA, i, k, nnA, dist); - mp::knn(distB, i, k, nnB, dist); + arma::uvec nnA = arma::sort_index(distA.col(i)); + nnA = nnA.subvec(2, k + 1); + arma::uvec nnB = arma::sort_index(distB.col(i)); + nnB = nnB.subvec(2, k + 1); std::sort(nnA.begin(), nnA.end()); std::sort(nnB.begin(), nnB.end()); arma::uword l; for (l = 0; nnA[l] == nnB[l] && l < k; l++); - np[i] = ((double) l) / k; + v[i] = ((double) l) / k; } - - return np; } arma::vec mp::silhouette(const arma::mat &distA, diff --git a/mp.h b/mp.h index 8f555a2..d2d0df4 100644 --- a/mp.h +++ b/mp.h @@ -13,7 +13,7 @@ arma::mat dist(const arma::mat &X, DistFunc dfunc = euclidean); void knn(const arma::mat &dmat, arma::uword i, arma::uword k, arma::uvec &nn, arma::vec &dist); // Evaluation measures -arma::vec neighborhoodPreservation(const arma::mat &distA, const arma::mat &distB, arma::uword k = 10); +void neighborhoodPreservation(const arma::mat &distA, const arma::mat &distB, arma::uword k, arma::vec &v); arma::vec silhouette(const arma::mat &distA, const arma::mat &distB, const arma::vec &labels); void aggregatedError(const arma::mat &distX, const arma::mat &distY, arma::vec &v); -- cgit v1.2.3