From ad4fbabeca2cbdf4cb47f1a923183027494ab0a8 Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Thu, 22 Oct 2015 16:39:15 -0200 Subject: Added PLMP as an alternative technique to use; knn() is now exposed in the mp namespace. --- main.cpp | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 92 insertions(+), 12 deletions(-) (limited to 'main.cpp') diff --git a/main.cpp b/main.cpp index 8355590..4aad1db 100644 --- a/main.cpp +++ b/main.cpp @@ -27,6 +27,69 @@ static QObject *mainProvider(QQmlEngine *engine, QJSEngine *scriptEngine) return Main::instance(); } +static inline double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi) { + Pi = arma::exp(-Di * beta); + double sumPi = arma::accu(Pi); + double h = log(sumPi) + beta * arma::accu(Di % Pi) / sumPi; + Pi /= sumPi; + return h; +} + +static void calcP(const arma::mat &X, arma::mat &P, double perplexity = 30, double tol = 1e-5) { + arma::colvec sumX = arma::sum(X % X, 1); + arma::mat D = -2 * (X * X.t()); + D.each_col() += sumX; + arma::inplace_trans(D); + D.each_col() += sumX; + D.diag() *= 0; + double logU = log(perplexity); + arma::rowvec beta(X.n_rows, arma::fill::ones); + + arma::rowvec Pi(X.n_rows); + for (arma::uword i = 0; i < X.n_rows; i++) { + double betaMin = -arma::datum::inf; + double betaMax = arma::datum::inf; + arma::rowvec Di = D.row(i); + double h = hBeta(Di, beta[i], Pi); + + double hDiff = h - logU; + for (int tries = 0; fabs(hDiff) > tol && tries < 50; tries++) { + if (hDiff > 0) { + betaMin = beta[i]; + if (betaMax == arma::datum::inf || betaMax == -arma::datum::inf) { + beta[i] *= 2; + } else { + beta[i] = (beta[i] + betaMax) / 2.; + } + } else { + betaMax = beta[i]; + if (betaMin == arma::datum::inf || betaMin == -arma::datum::inf) { + beta[i] /= 2; + } else { + beta[i] = (beta[i] + betaMin) / 2.; + } + } + + h = hBeta(Di, beta[i], Pi); + hDiff = h - logU; + } + + P.row(i) = Pi; + } +} + +arma::uvec relevanceSampling(const arma::mat &X, int subsampleSize) +{ + arma::mat P(X.n_rows, X.n_rows); + calcP(X, P); + P = (P + P.t()); + P /= arma::accu(P); + P.transform([](double p) { return std::max(p, 1e-12); }); + + arma::uvec indices = arma::sort_index(arma::sum(P)); + return indices(arma::span(0, subsampleSize - 1)); +} + int main(int argc, char **argv) { QApplication app(argc, argv); @@ -55,24 +118,40 @@ int main(int argc, char **argv) } Main *m = Main::instance(); - if (parser.isSet(indicesFileOutputOption)) { - m->setIndicesSavePath(parser.value(indicesFileOutputOption)); - } - if (parser.isSet(subsampleFileOutputOption)) { - m->setSubsampleSavePath(parser.value(subsampleFileOutputOption)); - } m->loadDataset(args[0].toStdString()); - arma::mat X = m->X(); arma::vec labels = m->labels(); arma::uword n = X.n_rows; - arma::uword subsampleSize = (arma::uword) n / 10.f; - arma::uvec sampleIndices = arma::randi(subsampleSize, arma::distr_param(0, n-1)); - m->setSubsampleIndices(sampleIndices); + int subsampleSize = 3 * ((int) sqrt(n)); + arma::uvec sampleIndices; + arma::mat Ys; + + if (parser.isSet(indicesFileOutputOption)) { + const QString &indicesFilename = parser.value(indicesFileOutputOption); + m->setIndicesSavePath(indicesFilename); + QFile indicesFile(indicesFilename); + if (indicesFile.exists()) { + sampleIndices.load(indicesFilename.toStdString(), arma::raw_ascii); + subsampleSize = sampleIndices.n_elem; + } else { + sampleIndices = relevanceSampling(X, subsampleSize); + } + } + if (parser.isSet(subsampleFileOutputOption)) { + const QString &subsampleFilename = parser.value(subsampleFileOutputOption); + m->setSubsampleSavePath(subsampleFilename); + QFile subsampleFile(subsampleFilename); + if (subsampleFile.exists()) { + Ys.load(subsampleFilename.toStdString(), arma::raw_ascii); + } else { + Ys.set_size(subsampleSize, 2); + mp::forceScheme(mp::dist(X.rows(sampleIndices)), Ys); + } + } - arma::mat Ys(subsampleSize, 2, arma::fill::randn); - mp::forceScheme(mp::dist(X.rows(sampleIndices)), Ys); + m->setSubsampleIndices(sampleIndices); + m->setSubsample(Ys); qmlRegisterType("PM", 1, 0, "Scatterplot"); qmlRegisterType("PM", 1, 0, "HistoryGraph"); @@ -113,6 +192,7 @@ int main(int argc, char **argv) // Update LAMP projection as the subsample is modified InteractionHandler interactionHandler(X, sampleIndices); + interactionHandler.setTechnique(InteractionHandler::TECHNIQUE_LAMP); QObject::connect(subsamplePlot, SIGNAL(xyChanged(const arma::mat &)), &interactionHandler, SLOT(setSubsample(const arma::mat &))); QObject::connect(subsamplePlot, SIGNAL(xyInteractivelyChanged(const arma::mat &)), -- cgit v1.2.3