From 4e0b46a727f6ea727b9e7920150609c58ce65fce Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Sat, 30 May 2015 02:00:08 -0300 Subject: Added tSNE. Code improvements. --- forceScheme.cpp | 4 +- lamp.cpp | 4 +- main.cpp | 1 + mp.h | 6 ++- pm.pro | 1 + scatterplot.cpp | 16 ++------ tsne.cpp | 124 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 140 insertions(+), 16 deletions(-) create mode 100644 tsne.cpp diff --git a/forceScheme.cpp b/forceScheme.cpp index de856ba..0dc7e34 100644 --- a/forceScheme.cpp +++ b/forceScheme.cpp @@ -3,6 +3,8 @@ #include #include +static const double EPSILON = 1e-3; + typedef arma::uvec V; arma::mat mp::forceScheme(const arma::mat &D, @@ -28,7 +30,7 @@ arma::mat mp::forceScheme(const arma::mat &D, continue; arma::rowvec direction(Y.row(*b) - Y.row(*a)); - double d2 = std::max(arma::norm(direction, 2), mp::EPSILON); + double d2 = std::max(arma::norm(direction, 2), EPSILON); double delta = (D(*a, *b) - d2) / fraction; deltaSum += fabs(delta); Y.row(*b) += delta * (direction / d2); diff --git a/lamp.cpp b/lamp.cpp index 90be140..b7a1027 100644 --- a/lamp.cpp +++ b/lamp.cpp @@ -2,6 +2,8 @@ #include +static const double EPSILON = 1e-3; + arma::mat mp::lamp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys) { arma::mat projection(X.n_rows, 2); @@ -22,7 +24,7 @@ void mp::lamp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::m arma::rowvec alphas(sampleSize); for (arma::uword j = 0; j < sampleSize; j++) { double dist = arma::accu(arma::square(Xs.row(j) - point)); - alphas[j] = 1. / std::max(dist, mp::EPSILON); + alphas[j] = 1. / std::max(dist, EPSILON); } double alphas_sum = arma::accu(alphas); diff --git a/main.cpp b/main.cpp index 2024ae7..5da297f 100644 --- a/main.cpp +++ b/main.cpp @@ -51,5 +51,6 @@ int main(int argc, char **argv) QObject::connect(interactionHandler.get(), SIGNAL(subsampleChanged(const arma::mat &)), plot, SLOT(setData(const arma::mat &))); interactionHandler.get()->setSubsample(Ys); + return app.exec(); } diff --git a/mp.h b/mp.h index 996da87..7d6b535 100644 --- a/mp.h +++ b/mp.h @@ -2,11 +2,15 @@ namespace mp { -static const double EPSILON = 1e-3; double euclidean(const arma::rowvec &x1, const arma::rowvec &x2); arma::mat dist(const arma::mat &X, double (*distCalc)(const arma::rowvec &, const arma::rowvec &) = euclidean); + arma::mat lamp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys); void lamp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys, arma::mat &Y); + arma::mat forceScheme(const arma::mat &D, arma::mat &Y, size_t maxIter = 20, double tol = 1e-3, double fraction = 8); +arma::mat tSNE(const arma::mat &X, arma::uword k = 2, double perplexity = 30, arma::uword nIter = 1000); +void tSNE(const arma::mat &X, arma::mat &Y, double perplexity = 30, arma::uword nIter = 1000); + } // namespace mp diff --git a/pm.pro b/pm.pro index af9a10d..3ca2aaa 100644 --- a/pm.pro +++ b/pm.pro @@ -12,6 +12,7 @@ SOURCES += main.cpp \ interactionhandler.cpp \ lamp.cpp \ forceScheme.cpp \ + tsne.cpp \ dist.cpp RESOURCES += pm.qrc diff --git a/scatterplot.cpp b/scatterplot.cpp index 8fb0b56..0c2fbc1 100644 --- a/scatterplot.cpp +++ b/scatterplot.cpp @@ -74,22 +74,12 @@ void updateCircleGeometry(QSGGeometry *geometry, float size, float cx, float cy) } } -void updateSquareGeometry(QSGGeometry *geometry, float size, float cx, float cy) -{ - float r = size / 2; - QSGGeometry::Point2D *vertexData = geometry->vertexDataAsPoint2D(); - vertexData[0].set(cx - r, cy - r); - vertexData[1].set(cx + r, cy - r); - vertexData[2].set(cx + r, cy + r); - vertexData[3].set(cx - r, cy + r); -} - -float Scatterplot::fromDataXToScreenX(float x) +inline float Scatterplot::fromDataXToScreenX(float x) { return PADDING + (x - m_xmin) / (m_xmax - m_xmin) * (width() - 2*PADDING); } -float Scatterplot::fromDataYToScreenY(float y) +inline float Scatterplot::fromDataYToScreenY(float y) { return PADDING + (y - m_ymin) / (m_ymax - m_ymin) * (height() - 2*PADDING); } @@ -111,7 +101,7 @@ QSGNode *Scatterplot::newGlyphNodeTree() { glyphNode->setMaterial(material); glyphNode->setFlag(QSGNode::OwnsMaterial); - // Place the glyph geometry node under a opacity node + // Place the glyph geometry node under an opacity node QSGOpacityNode *glyphOpacityNode = new QSGOpacityNode; glyphOpacityNode->appendChildNode(glyphNode); node->appendChildNode(glyphOpacityNode); diff --git a/tsne.cpp b/tsne.cpp new file mode 100644 index 0000000..50022eb --- /dev/null +++ b/tsne.cpp @@ -0,0 +1,124 @@ +#include "mp.h" + +#include +#include + +static const double ETA = 500; +static const double MIN_GAIN = 1e-2; +static const double EPSILON = 1e-12; +static const double INITIAL_MOMENTUM = 0.5; +static const double FINAL_MOMENTUM = 0.8; +static const double EARLY_EXAGGERATION = 4.; +static const double GAIN_FRACTION = 0.2; + +static const int MOMENTUM_THRESHOLD_ITER = 20; +static const int EXAGGERATION_THRESHOLD_ITER = 100; +static const int MAX_BINSEARCH_TRIES = 50; + +static void calcP(const arma::mat &X, arma::mat &P, double perplexity, double tol = 1e-5); +static double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi); + +arma::mat mp::tSNE(const arma::mat &X, arma::uword k, double perplexity, arma::uword nIter) +{ + arma::mat Y(X.n_rows, k); + mp::tSNE(X, Y, perplexity, nIter); + return Y; +} + +void mp::tSNE(const arma::mat &X, arma::mat &Y, double perplexity, arma::uword nIter) +{ + double momentum; + arma::uword n = X.n_rows; + arma::uword k = Y.n_cols; + arma::mat Q(n, n); + arma::mat dY(n, k), + gains(n, k, arma::fill::ones), + iY(n, k, arma::fill::zeros); + + arma::mat P(n, n, arma::fill::zeros); + calcP(X, P, perplexity); + P = (P + P.t()); + P /= arma::accu(P); + P *= EARLY_EXAGGERATION; + P.transform([](double v) { return std::max(v, EPSILON); }); + + for (arma::uword iter = 0; iter < nIter; iter++) { + arma::vec sumY = arma::sum(Y % Y, 1); + arma::mat num = -2. * (Y * Y.t()); + num.each_col() += sumY; + arma::inplace_trans(num); + num.each_col() += sumY; + num = 1. / (1. + num); + num.diag() *= 0; + Q = num / arma::accu(num); + Q.transform([](double v) { return std::max(v, EPSILON); }); + + for (arma::uword i = 0; i < n; i++) { + arma::mat tmp = -Y; + tmp.each_row() += Y.row(i); + tmp.each_col() %= (P.col(i) - Q.col(i)) % num.col(i); + dY.row(i) = arma::sum(tmp, 0); + } + + momentum = (iter < MOMENTUM_THRESHOLD_ITER) ? INITIAL_MOMENTUM : FINAL_MOMENTUM; + gains = (gains + GAIN_FRACTION) % ((dY > 0) != (iY > 0)) + + (gains * (1 - GAIN_FRACTION)) % ((dY > 0) == (iY > 0)); + gains.transform([](double v) { return std::max(v, MIN_GAIN); }); + iY = momentum * iY - ETA * (gains % dY); + Y += iY; + Y.each_row() -= mean(Y, 0); + + if (iter == EXAGGERATION_THRESHOLD_ITER) + P /= EARLY_EXAGGERATION; // remove early exaggeration + } +} + +static void calcP(const arma::mat &X, arma::mat &P, double perplexity, double tol) { + arma::colvec sumX = arma::sum(X % X, 1); + arma::mat D = -2 * (X * X.t()); + D.each_col() += sumX; + arma::inplace_trans(D); + D.each_col() += sumX; + D.diag() *= 0; + double logU = log(perplexity); + arma::rowvec beta(X.n_rows, arma::fill::ones); + + arma::rowvec Pi(X.n_rows); + for (arma::uword i = 0; i < X.n_rows; i++) { + double betaMin = -arma::datum::inf; + double betaMax = arma::datum::inf; + arma::rowvec Di = D.row(i); + double h = hBeta(Di, beta[i], Pi); + + double hDiff = h - logU; + for (int tries = 0; fabs(hDiff) > tol && tries < MAX_BINSEARCH_TRIES; tries++) { + if (hDiff > 0) { + betaMin = beta[i]; + if (betaMax == arma::datum::inf || betaMax == -arma::datum::inf) + beta[i] *= 2; + else + beta[i] = (beta[i] + betaMax) / 2.; + } else { + betaMax = beta[i]; + if (betaMin == arma::datum::inf || betaMin == -arma::datum::inf) + beta[i] /= 2; + else + beta[i] = (beta[i] + betaMin) / 2.; + } + + + h = hBeta(Di, beta[i], Pi); + hDiff = h - logU; + } + + P.row(i) = Pi; + } +} + +static inline double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi) { + Pi = arma::exp(-Di * beta); + double sumPi = arma::accu(Pi); + double h = log(sumPi) + beta * arma::accu(Di % Pi) / sumPi; + Pi /= sumPi; + return h; +} -- cgit v1.2.3