From 4e0b46a727f6ea727b9e7920150609c58ce65fce Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Sat, 30 May 2015 02:00:08 -0300 Subject: Added tSNE. Code improvements. --- tsne.cpp | 124 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 tsne.cpp (limited to 'tsne.cpp') diff --git a/tsne.cpp b/tsne.cpp new file mode 100644 index 0000000..50022eb --- /dev/null +++ b/tsne.cpp @@ -0,0 +1,124 @@ +#include "mp.h" + +#include +#include + +static const double ETA = 500; +static const double MIN_GAIN = 1e-2; +static const double EPSILON = 1e-12; +static const double INITIAL_MOMENTUM = 0.5; +static const double FINAL_MOMENTUM = 0.8; +static const double EARLY_EXAGGERATION = 4.; +static const double GAIN_FRACTION = 0.2; + +static const int MOMENTUM_THRESHOLD_ITER = 20; +static const int EXAGGERATION_THRESHOLD_ITER = 100; +static const int MAX_BINSEARCH_TRIES = 50; + +static void calcP(const arma::mat &X, arma::mat &P, double perplexity, double tol = 1e-5); +static double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi); + +arma::mat mp::tSNE(const arma::mat &X, arma::uword k, double perplexity, arma::uword nIter) +{ + arma::mat Y(X.n_rows, k); + mp::tSNE(X, Y, perplexity, nIter); + return Y; +} + +void mp::tSNE(const arma::mat &X, arma::mat &Y, double perplexity, arma::uword nIter) +{ + double momentum; + arma::uword n = X.n_rows; + arma::uword k = Y.n_cols; + arma::mat Q(n, n); + arma::mat dY(n, k), + gains(n, k, arma::fill::ones), + iY(n, k, arma::fill::zeros); + + arma::mat P(n, n, arma::fill::zeros); + calcP(X, P, perplexity); + P = (P + P.t()); + P /= arma::accu(P); + P *= EARLY_EXAGGERATION; + P.transform([](double v) { return std::max(v, EPSILON); }); + + for (arma::uword iter = 0; iter < nIter; iter++) { + arma::vec sumY = arma::sum(Y % Y, 1); + arma::mat num = -2. * (Y * Y.t()); + num.each_col() += sumY; + arma::inplace_trans(num); + num.each_col() += sumY; + num = 1. / (1. + num); + num.diag() *= 0; + Q = num / arma::accu(num); + Q.transform([](double v) { return std::max(v, EPSILON); }); + + for (arma::uword i = 0; i < n; i++) { + arma::mat tmp = -Y; + tmp.each_row() += Y.row(i); + tmp.each_col() %= (P.col(i) - Q.col(i)) % num.col(i); + dY.row(i) = arma::sum(tmp, 0); + } + + momentum = (iter < MOMENTUM_THRESHOLD_ITER) ? INITIAL_MOMENTUM : FINAL_MOMENTUM; + gains = (gains + GAIN_FRACTION) % ((dY > 0) != (iY > 0)) + + (gains * (1 - GAIN_FRACTION)) % ((dY > 0) == (iY > 0)); + gains.transform([](double v) { return std::max(v, MIN_GAIN); }); + iY = momentum * iY - ETA * (gains % dY); + Y += iY; + Y.each_row() -= mean(Y, 0); + + if (iter == EXAGGERATION_THRESHOLD_ITER) + P /= EARLY_EXAGGERATION; // remove early exaggeration + } +} + +static void calcP(const arma::mat &X, arma::mat &P, double perplexity, double tol) { + arma::colvec sumX = arma::sum(X % X, 1); + arma::mat D = -2 * (X * X.t()); + D.each_col() += sumX; + arma::inplace_trans(D); + D.each_col() += sumX; + D.diag() *= 0; + double logU = log(perplexity); + arma::rowvec beta(X.n_rows, arma::fill::ones); + + arma::rowvec Pi(X.n_rows); + for (arma::uword i = 0; i < X.n_rows; i++) { + double betaMin = -arma::datum::inf; + double betaMax = arma::datum::inf; + arma::rowvec Di = D.row(i); + double h = hBeta(Di, beta[i], Pi); + + double hDiff = h - logU; + for (int tries = 0; fabs(hDiff) > tol && tries < MAX_BINSEARCH_TRIES; tries++) { + if (hDiff > 0) { + betaMin = beta[i]; + if (betaMax == arma::datum::inf || betaMax == -arma::datum::inf) + beta[i] *= 2; + else + beta[i] = (beta[i] + betaMax) / 2.; + } else { + betaMax = beta[i]; + if (betaMin == arma::datum::inf || betaMin == -arma::datum::inf) + beta[i] /= 2; + else + beta[i] = (beta[i] + betaMin) / 2.; + } + + + h = hBeta(Di, beta[i], Pi); + hDiff = h - logU; + } + + P.row(i) = Pi; + } +} + +static inline double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi) { + Pi = arma::exp(-Di * beta); + double sumPi = arma::accu(Pi); + double h = log(sumPi) + beta * arma::accu(Di % Pi) / sumPi; + Pi /= sumPi; + return h; +} -- cgit v1.2.3