aboutsummaryrefslogtreecommitdiff
path: root/tsne.cpp
blob: 50022eb8010aeb529be215722bd688396f26d006 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include "mp.h"

#include <algorithm>
#include <cmath>

static const double ETA = 500;
static const double MIN_GAIN           = 1e-2;
static const double EPSILON            = 1e-12;
static const double INITIAL_MOMENTUM   = 0.5;
static const double FINAL_MOMENTUM     = 0.8;
static const double EARLY_EXAGGERATION = 4.;
static const double GAIN_FRACTION      = 0.2;

static const int MOMENTUM_THRESHOLD_ITER     = 20;
static const int EXAGGERATION_THRESHOLD_ITER = 100;
static const int MAX_BINSEARCH_TRIES         = 50;

static void calcP(const arma::mat &X, arma::mat &P, double perplexity, double tol = 1e-5);
static double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi);

arma::mat mp::tSNE(const arma::mat &X, arma::uword k, double perplexity, arma::uword nIter)
{
    arma::mat Y(X.n_rows, k);
    mp::tSNE(X, Y, perplexity, nIter);
    return Y;
}

void mp::tSNE(const arma::mat &X, arma::mat &Y, double perplexity, arma::uword nIter)
{
    double momentum;
    arma::uword n = X.n_rows;
    arma::uword k = Y.n_cols;
    arma::mat Q(n, n);
    arma::mat dY(n, k),
              gains(n, k, arma::fill::ones),
              iY(n, k, arma::fill::zeros);

    arma::mat P(n, n, arma::fill::zeros);
    calcP(X, P, perplexity);
    P = (P + P.t());
    P /= arma::accu(P);
    P *= EARLY_EXAGGERATION;
    P.transform([](double v) { return std::max(v, EPSILON); });

    for (arma::uword iter = 0; iter < nIter; iter++) {
        arma::vec sumY = arma::sum(Y % Y, 1);
        arma::mat num = -2. * (Y * Y.t());
        num.each_col() += sumY;
        arma::inplace_trans(num);
        num.each_col() += sumY;
        num = 1. / (1. + num);
        num.diag() *= 0;
        Q = num / arma::accu(num);
        Q.transform([](double v) { return std::max(v, EPSILON); });

        for (arma::uword i = 0; i < n; i++) {
            arma::mat tmp = -Y;
            tmp.each_row() += Y.row(i);
            tmp.each_col() %= (P.col(i) - Q.col(i)) % num.col(i);
            dY.row(i) = arma::sum(tmp, 0);
        }

        momentum = (iter < MOMENTUM_THRESHOLD_ITER) ? INITIAL_MOMENTUM : FINAL_MOMENTUM;
        gains = (gains +       GAIN_FRACTION) % ((dY > 0) != (iY > 0))
              + (gains * (1 - GAIN_FRACTION)) % ((dY > 0) == (iY > 0));
        gains.transform([](double v) { return std::max(v, MIN_GAIN); });
        iY = momentum * iY - ETA * (gains % dY);
        Y += iY;
        Y.each_row() -= mean(Y, 0);

        if (iter == EXAGGERATION_THRESHOLD_ITER)
            P /= EARLY_EXAGGERATION; // remove early exaggeration
    }
}

static void calcP(const arma::mat &X, arma::mat &P, double perplexity, double tol) {
    arma::colvec sumX = arma::sum(X % X, 1);
    arma::mat D = -2 * (X * X.t());
    D.each_col() += sumX;
    arma::inplace_trans(D);
    D.each_col() += sumX;
    D.diag() *= 0;
    double logU = log(perplexity);
    arma::rowvec beta(X.n_rows, arma::fill::ones);

    arma::rowvec Pi(X.n_rows);
    for (arma::uword i = 0; i < X.n_rows; i++) {
        double betaMin = -arma::datum::inf;
        double betaMax =  arma::datum::inf;
        arma::rowvec Di = D.row(i);
        double h = hBeta(Di, beta[i], Pi);

        double hDiff = h - logU;
        for (int tries = 0; fabs(hDiff) > tol && tries < MAX_BINSEARCH_TRIES; tries++) {
            if (hDiff > 0) {
                betaMin = beta[i];
                if (betaMax == arma::datum::inf || betaMax == -arma::datum::inf)
                    beta[i] *= 2;
                else
                    beta[i] = (beta[i] + betaMax) / 2.;
            } else {
                betaMax = beta[i];
                if (betaMin == arma::datum::inf || betaMin == -arma::datum::inf)
                    beta[i] /= 2;
                else
                    beta[i] = (beta[i] + betaMin) / 2.;
            }

            
            h = hBeta(Di, beta[i], Pi);
            hDiff = h - logU;
        }

        P.row(i) = Pi;
    }
}

static inline double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi) {
    Pi = arma::exp(-Di * beta);
    double sumPi = arma::accu(Pi);
    double h = log(sumPi) + beta * arma::accu(Di % Pi) / sumPi;
    Pi /= sumPi;
    return h;
}