From a3858735ec68d518f830307da65606694345dca1 Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Mon, 1 Jun 2015 18:12:25 -0300 Subject: Bugfix klDivergence. --- measures.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'measures.cpp') diff --git a/measures.cpp b/measures.cpp index dff8494..8829ee8 100644 --- a/measures.cpp +++ b/measures.cpp @@ -43,7 +43,8 @@ void mp::klDivergence(const arma::mat &P, const arma::mat &Q, arma::vec diver) double mp::klDivergence(const arma::rowvec &pi, const arma::rowvec &qi) { - return arma::accu(removeJ % (pi % arma::log(pi / qi))); + // Pii and Qii should both be 1, zeroing the i-th term in the sum below + return arma::accu(pi % arma::log(pi / qi)); } arma::mat mp::d2p(const arma::mat &D, const arma::vec &sigmas) @@ -58,7 +59,6 @@ void mp::d2p(const arma::mat &D, const arma::vec &sigmas, arma::mat &P) /* * WARNING: assumes D and sigmas are already squared */ - assert(D.n_rows == D.n_cols); assert(P.n_rows == P.n_cols); assert(D.n_rows == P.n_rows); @@ -72,4 +72,4 @@ void mp::d2p(const arma::mat &D, const arma::vec &sigmas, arma::mat &P) for (arma::uword j = 0; j < n; j++) P(i, j) = exp(-D(i, j) / sigmas(i)) / den; } -} \ No newline at end of file +} -- cgit v1.2.3