aboutsummaryrefslogtreecommitdiff
path: root/main.cpp
diff options
context:
space:
mode:
author Samuel Fadel <samuelfadel@gmail.com> 2015-10-22 16:39:15 -0200
committer Samuel Fadel <samuelfadel@gmail.com> 2015-10-22 16:39:15 -0200
commit ad4fbabeca2cbdf4cb47f1a923183027494ab0a8 (patch)
tree 4a9c16454ff4d802b23a8271f4005ee1846f7b1b /main.cpp
parent 99ac0af03e1695ba4de2c42e949fce61b84850e5 (diff)
Added PLMP as an alternative technique to use; knn() is now exposed in the mp namespace.
Diffstat (limited to 'main.cpp')
-rw-r--r-- main.cpp 104
1 files changed, 92 insertions, 12 deletions
diff --git a/main.cpp b/main.cpp
index 8355590..4aad1db 100644
--- a/main.cpp
+++ b/main.cpp
@@ -27,6 +27,69 @@ static QObject *mainProvider(QQmlEngine *engine, QJSEngine *scriptEngine)
return Main::instance();
}
+static inline double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi) { // Overwrites Pi with exp(-beta * Di) normalized to sum to 1 and returns the entropy (natural log) of that distribution.
+ Pi = arma::exp(-Di * beta); // unnormalized weights, one per entry of Di
+ double sumPi = arma::accu(Pi); // normalization constant (sum of all weights)
+ double h = log(sumPi) + beta * arma::accu(Di % Pi) / sumPi; // algebraically -sum(p * log(p)) with p = Pi / sumPi; % is the element-wise product
+ Pi /= sumPi; // Pi now sums to 1; NOTE(review): no guard here against sumPi underflowing to 0
+ return h;
+}
+
+static void calcP(const arma::mat &X, arma::mat &P, double perplexity = 30, double tol = 1e-5) { // For each row i of X, searches beta[i] so that hBeta() of row i's squared distances is within tol of log(perplexity), then stores the resulting weights in P.row(i). P is only written row by row here, so it must already have X.n_rows rows and columns.
+ arma::colvec sumX = arma::sum(X % X, 1); // squared norm of every row of X
+ arma::mat D = -2 * (X * X.t()); // -2 * <x_i, x_j>
+ D.each_col() += sumX; // adds |x_i|^2 to every entry of row i
+ arma::inplace_trans(D); // transpose so the next each_col() supplies the other index's norm
+ D.each_col() += sumX; // D(i, j) = |x_i|^2 + |x_j|^2 - 2 <x_i, x_j>, i.e. squared Euclidean distance
+ D.diag() *= 0; // force self-distances to exactly 0
+ double logU = log(perplexity); // target value for hBeta()
+ arma::rowvec beta(X.n_rows, arma::fill::ones); // every search starts from beta = 1
+
+ arma::rowvec Pi(X.n_rows); // scratch row reused for every i
+ for (arma::uword i = 0; i < X.n_rows; i++) {
+ double betaMin = -arma::datum::inf; // infinite bounds mean "no bracket found yet" on that side
+ double betaMax = arma::datum::inf;
+ arma::rowvec Di = D.row(i); // full row, including the zero self-distance at column i
+ double h = hBeta(Di, beta[i], Pi);
+
+ double hDiff = h - logU;
+ for (int tries = 0; fabs(hDiff) > tol && tries < 50; tries++) { // gives up after 50 refinements even if tol is not reached
+ if (hDiff > 0) { // h too high: beta must grow
+ betaMin = beta[i];
+ if (betaMax == arma::datum::inf || betaMax == -arma::datum::inf) {
+ beta[i] *= 2; // no upper bound yet: expand by doubling
+ } else {
+ beta[i] = (beta[i] + betaMax) / 2.; // bisect towards the known upper bound
+ }
+ } else { // h too low (or equal): beta must shrink
+ betaMax = beta[i];
+ if (betaMin == arma::datum::inf || betaMin == -arma::datum::inf) {
+ beta[i] /= 2; // no lower bound yet: contract by halving
+ } else {
+ beta[i] = (beta[i] + betaMin) / 2.; // bisect towards the known lower bound
+ }
+ }
+
+ h = hBeta(Di, beta[i], Pi); // re-evaluate with the updated beta; also refreshes Pi
+ hDiff = h - logU;
+ }
+
+ P.row(i) = Pi; // weights from the last hBeta() call for this row
+ }
+}
+
+arma::uvec relevanceSampling(const arma::mat &X, int subsampleSize) // Returns subsampleSize row indices of X, chosen as the rows with the smallest column sums of the symmetrized, normalized P matrix built by calcP().
+{
+ arma::mat P(X.n_rows, X.n_rows); // calcP() fills this row by row, so it is sized up front
+ calcP(X, P); // uses calcP's default perplexity (30) and tolerance (1e-5)
+ P = (P + P.t()); // symmetrize
+ P /= arma::accu(P); // normalize so all entries sum to 1
+ P.transform([](double p) { return std::max(p, 1e-12); }); // clamp every entry to at least 1e-12
+
+ arma::uvec indices = arma::sort_index(arma::sum(P)); // column sums of P, indices in ascending order of that sum
+ return indices(arma::span(0, subsampleSize - 1)); // first subsampleSize entries; NOTE(review): presumably requires 1 <= subsampleSize <= X.n_rows, not checked here
+}
+
int main(int argc, char **argv)
{
QApplication app(argc, argv);
@@ -55,24 +118,40 @@ int main(int argc, char **argv)
}
Main *m = Main::instance();
- if (parser.isSet(indicesFileOutputOption)) {
- m->setIndicesSavePath(parser.value(indicesFileOutputOption));
- }
- if (parser.isSet(subsampleFileOutputOption)) {
- m->setSubsampleSavePath(parser.value(subsampleFileOutputOption));
- }
m->loadDataset(args[0].toStdString());
-
arma::mat X = m->X();
arma::vec labels = m->labels();
arma::uword n = X.n_rows;
- arma::uword subsampleSize = (arma::uword) n / 10.f;
- arma::uvec sampleIndices = arma::randi<arma::uvec>(subsampleSize, arma::distr_param(0, n-1));
- m->setSubsampleIndices(sampleIndices);
+ int subsampleSize = 3 * ((int) sqrt(n));
+ arma::uvec sampleIndices;
+ arma::mat Ys;
+
+ if (parser.isSet(indicesFileOutputOption)) {
+ const QString &indicesFilename = parser.value(indicesFileOutputOption);
+ m->setIndicesSavePath(indicesFilename);
+ QFile indicesFile(indicesFilename);
+ if (indicesFile.exists()) {
+ sampleIndices.load(indicesFilename.toStdString(), arma::raw_ascii);
+ subsampleSize = sampleIndices.n_elem;
+ } else {
+ sampleIndices = relevanceSampling(X, subsampleSize);
+ }
+ }
+ if (parser.isSet(subsampleFileOutputOption)) {
+ const QString &subsampleFilename = parser.value(subsampleFileOutputOption);
+ m->setSubsampleSavePath(subsampleFilename);
+ QFile subsampleFile(subsampleFilename);
+ if (subsampleFile.exists()) {
+ Ys.load(subsampleFilename.toStdString(), arma::raw_ascii);
+ } else {
+ Ys.set_size(subsampleSize, 2);
+ mp::forceScheme(mp::dist(X.rows(sampleIndices)), Ys);
+ }
+ }
- arma::mat Ys(subsampleSize, 2, arma::fill::randn);
- mp::forceScheme(mp::dist(X.rows(sampleIndices)), Ys);
+ m->setSubsampleIndices(sampleIndices);
+ m->setSubsample(Ys);
qmlRegisterType<Scatterplot>("PM", 1, 0, "Scatterplot");
qmlRegisterType<HistoryGraph>("PM", 1, 0, "HistoryGraph");
@@ -113,6 +192,7 @@ int main(int argc, char **argv)
// Update LAMP projection as the subsample is modified
InteractionHandler interactionHandler(X, sampleIndices);
+ interactionHandler.setTechnique(InteractionHandler::TECHNIQUE_LAMP);
QObject::connect(subsamplePlot, SIGNAL(xyChanged(const arma::mat &)),
&interactionHandler, SLOT(setSubsample(const arma::mat &)));
QObject::connect(subsamplePlot, SIGNAL(xyInteractivelyChanged(const arma::mat &)),