aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--interactionhandler.cpp4
-rw-r--r--interactionhandler.h2
-rw-r--r--knn.cpp30
-rw-r--r--main.cpp104
-rw-r--r--main_view.qml10
-rw-r--r--measures.cpp42
-rw-r--r--mp.h8
-rw-r--r--plmp.cpp22
-rw-r--r--pm.pro2
9 files changed, 167 insertions, 57 deletions
diff --git a/interactionhandler.cpp b/interactionhandler.cpp
index b79fdd8..2eb55bb 100644
--- a/interactionhandler.cpp
+++ b/interactionhandler.cpp
@@ -15,7 +15,11 @@ void InteractionHandler::setSubsample(const arma::mat &Ys)
{
switch (m_technique) {
case TECHNIQUE_PLMP:
+ mp::plmp(m_X, m_sampleIndices, Ys, m_Y);
+ break;
case TECHNIQUE_LSP:
+ // mp::lsp(m_X, m_sampleIndices, Ys, m_Y);
+ break;
case TECHNIQUE_LAMP:
mp::lamp(m_X, m_sampleIndices, Ys, m_Y);
break;
diff --git a/interactionhandler.h b/interactionhandler.h
index 0104d65..784746e 100644
--- a/interactionhandler.h
+++ b/interactionhandler.h
@@ -16,6 +16,8 @@ public:
InteractionHandler(const arma::mat &X, const arma::uvec &sampleIndices);
+ void setTechnique(InteractiveTechnique technique) { m_technique = technique; }
+
signals:
void subsampleChanged(const arma::mat &Y);
diff --git a/knn.cpp b/knn.cpp
new file mode 100644
index 0000000..579a76f
--- /dev/null
+++ b/knn.cpp
@@ -0,0 +1,30 @@
+#include "mp.h"
+
+void mp::knn(const arma::mat &dmat, arma::uword i, arma::uword k, arma::uvec &nn, arma::vec &dist)
+{
+ arma::uword n = dist.n_rows;
+ double dmax = arma::datum::inf;
+ nn.fill(i);
+ dist.fill(dmax);
+ if (k > n) {
+ return;
+ }
+
+ const arma::vec &dvec = dmat.col(i);
+ for (arma::uword j = 0; j < n; j++) {
+ if (j == i || dvec[j] > dmax) {
+ continue;
+ }
+
+ arma::uword l;
+ for (l = 0; dist[l] < dvec[j] && l < k; l++);
+ for (arma::uword m = k - 1; m > l; m--) {
+ nn[m] = nn[m - 1];
+ dist[m] = dist[m - 1];
+ }
+
+ nn[l] = j;
+ dist[l] = dvec[j];
+ dmax = dist[k - 1];
+ }
+}
diff --git a/main.cpp b/main.cpp
index 8355590..4aad1db 100644
--- a/main.cpp
+++ b/main.cpp
@@ -27,6 +27,69 @@ static QObject *mainProvider(QQmlEngine *engine, QJSEngine *scriptEngine)
return Main::instance();
}
+static inline double hBeta(const arma::rowvec &Di, double beta, arma::rowvec &Pi) {
+ Pi = arma::exp(-Di * beta);
+ double sumPi = arma::accu(Pi);
+ double h = log(sumPi) + beta * arma::accu(Di % Pi) / sumPi;
+ Pi /= sumPi;
+ return h;
+}
+
+static void calcP(const arma::mat &X, arma::mat &P, double perplexity = 30, double tol = 1e-5) {
+ arma::colvec sumX = arma::sum(X % X, 1);
+ arma::mat D = -2 * (X * X.t());
+ D.each_col() += sumX;
+ arma::inplace_trans(D);
+ D.each_col() += sumX;
+ D.diag() *= 0;
+ double logU = log(perplexity);
+ arma::rowvec beta(X.n_rows, arma::fill::ones);
+
+ arma::rowvec Pi(X.n_rows);
+ for (arma::uword i = 0; i < X.n_rows; i++) {
+ double betaMin = -arma::datum::inf;
+ double betaMax = arma::datum::inf;
+ arma::rowvec Di = D.row(i);
+ double h = hBeta(Di, beta[i], Pi);
+
+ double hDiff = h - logU;
+ for (int tries = 0; fabs(hDiff) > tol && tries < 50; tries++) {
+ if (hDiff > 0) {
+ betaMin = beta[i];
+ if (betaMax == arma::datum::inf || betaMax == -arma::datum::inf) {
+ beta[i] *= 2;
+ } else {
+ beta[i] = (beta[i] + betaMax) / 2.;
+ }
+ } else {
+ betaMax = beta[i];
+ if (betaMin == arma::datum::inf || betaMin == -arma::datum::inf) {
+ beta[i] /= 2;
+ } else {
+ beta[i] = (beta[i] + betaMin) / 2.;
+ }
+ }
+
+ h = hBeta(Di, beta[i], Pi);
+ hDiff = h - logU;
+ }
+
+ P.row(i) = Pi;
+ }
+}
+
+arma::uvec relevanceSampling(const arma::mat &X, int subsampleSize)
+{
+ arma::mat P(X.n_rows, X.n_rows);
+ calcP(X, P);
+ P = (P + P.t());
+ P /= arma::accu(P);
+ P.transform([](double p) { return std::max(p, 1e-12); });
+
+ arma::uvec indices = arma::sort_index(arma::sum(P));
+ return indices(arma::span(0, subsampleSize - 1));
+}
+
int main(int argc, char **argv)
{
QApplication app(argc, argv);
@@ -55,24 +118,40 @@ int main(int argc, char **argv)
}
Main *m = Main::instance();
- if (parser.isSet(indicesFileOutputOption)) {
- m->setIndicesSavePath(parser.value(indicesFileOutputOption));
- }
- if (parser.isSet(subsampleFileOutputOption)) {
- m->setSubsampleSavePath(parser.value(subsampleFileOutputOption));
- }
m->loadDataset(args[0].toStdString());
-
arma::mat X = m->X();
arma::vec labels = m->labels();
arma::uword n = X.n_rows;
- arma::uword subsampleSize = (arma::uword) n / 10.f;
- arma::uvec sampleIndices = arma::randi<arma::uvec>(subsampleSize, arma::distr_param(0, n-1));
- m->setSubsampleIndices(sampleIndices);
+ int subsampleSize = 3 * ((int) sqrt(n));
+ arma::uvec sampleIndices;
+ arma::mat Ys;
+
+ if (parser.isSet(indicesFileOutputOption)) {
+ const QString &indicesFilename = parser.value(indicesFileOutputOption);
+ m->setIndicesSavePath(indicesFilename);
+ QFile indicesFile(indicesFilename);
+ if (indicesFile.exists()) {
+ sampleIndices.load(indicesFilename.toStdString(), arma::raw_ascii);
+ subsampleSize = sampleIndices.n_elem;
+ } else {
+ sampleIndices = relevanceSampling(X, subsampleSize);
+ }
+ }
+ if (parser.isSet(subsampleFileOutputOption)) {
+ const QString &subsampleFilename = parser.value(subsampleFileOutputOption);
+ m->setSubsampleSavePath(subsampleFilename);
+ QFile subsampleFile(subsampleFilename);
+ if (subsampleFile.exists()) {
+ Ys.load(subsampleFilename.toStdString(), arma::raw_ascii);
+ } else {
+ Ys.set_size(subsampleSize, 2);
+ mp::forceScheme(mp::dist(X.rows(sampleIndices)), Ys);
+ }
+ }
- arma::mat Ys(subsampleSize, 2, arma::fill::randn);
- mp::forceScheme(mp::dist(X.rows(sampleIndices)), Ys);
+ m->setSubsampleIndices(sampleIndices);
+ m->setSubsample(Ys);
qmlRegisterType<Scatterplot>("PM", 1, 0, "Scatterplot");
qmlRegisterType<HistoryGraph>("PM", 1, 0, "HistoryGraph");
@@ -113,6 +192,7 @@ int main(int argc, char **argv)
// Update LAMP projection as the subsample is modified
InteractionHandler interactionHandler(X, sampleIndices);
+ interactionHandler.setTechnique(InteractionHandler::TECHNIQUE_LAMP);
QObject::connect(subsamplePlot, SIGNAL(xyChanged(const arma::mat &)),
&interactionHandler, SLOT(setSubsample(const arma::mat &)));
QObject::connect(subsamplePlot, SIGNAL(xyInteractivelyChanged(const arma::mat &)),
diff --git a/main_view.qml b/main_view.qml
index de9ec64..0cd29b5 100644
--- a/main_view.qml
+++ b/main_view.qml
@@ -21,7 +21,7 @@ ApplicationWindow {
Menu {
title: "View"
MenuItem {
- action: noneColorAction
+ action: labelColorAction
exclusiveGroup: coloringGroup
}
MenuItem {
@@ -111,12 +111,12 @@ ApplicationWindow {
id: coloringGroup
Action {
- id: noneColorAction
- text: "None"
- shortcut: "Shift+O"
+ id: labelColorAction
+ text: "Labels"
+ shortcut: "Shift+L"
checked: true
checkable: true
- onTriggered: console.log("None")
+ onTriggered: console.log("Labels")
}
Action {
diff --git a/measures.cpp b/measures.cpp
index a030096..daf3c78 100644
--- a/measures.cpp
+++ b/measures.cpp
@@ -4,44 +4,6 @@
#include <cmath>
#include <algorithm>
-static
-void knn(const arma::mat &dmat, arma::uword i, arma::uword k, arma::uvec &nn, arma::vec &dist)
-{
- arma::uword n = dist.n_rows;
- if (k > n) {
- return;
- }
-
- for (arma::uword j = 0, l = 0; l < k; j++, l++) {
- if (j == i) {
- j++;
- }
-
- nn[l] = j;
- dist[l] = dmat(i, j);
- }
-
- double dmax = *std::max_element(dist.begin(), dist.end());
- for (arma::uword j = 0; j < n; j++) {
- if (j == i) {
- continue;
- }
-
- if (dmat(i, j) < dmax) {
- dmax = dmat(i, j);
- arma::uword l;
- for (l = 0; dmat(i, j) > dist[l] && l < k; l++);
- for (arma::uword m = l + 1; m < k; m++) {
- nn[m] = nn[m - 1];
- dist[m] = dist[m - 1];
- }
-
- nn[l] = j;
- dist[l] = dmat(i, j);
- }
- }
-}
-
arma::vec mp::neighborhoodPreservation(const arma::mat &distA,
const arma::mat &distB,
arma::uword k)
@@ -55,8 +17,8 @@ arma::vec mp::neighborhoodPreservation(const arma::mat &distA,
arma::uvec nnB(k);
arma::vec dist(k);
- knn(distA, i, k, nnA, dist);
- knn(distB, i, k, nnB, dist);
+ mp::knn(distA, i, k, nnA, dist);
+ mp::knn(distB, i, k, nnB, dist);
std::sort(nnA.begin(), nnA.end());
std::sort(nnB.begin(), nnB.end());
diff --git a/mp.h b/mp.h
index 192f106..fe98a1c 100644
--- a/mp.h
+++ b/mp.h
@@ -10,6 +10,8 @@ typedef double (*DistFunc)(const arma::rowvec &, const arma::rowvec &);
double euclidean(const arma::rowvec &x1, const arma::rowvec &x2);
arma::mat dist(const arma::mat &X, DistFunc dfunc = euclidean);
+void knn(const arma::mat &dmat, arma::uword i, arma::uword k, arma::uvec &nn, arma::vec &dist);
+
// Evaluation measures
typedef arma::vec (*MeasureFunc)(const arma::mat &distA, const arma::mat &distB);
arma::vec neighborhoodPreservation(const arma::mat &distA, const arma::mat &distB, arma::uword k = 10);
@@ -19,6 +21,12 @@ arma::vec silhouette(const arma::mat &distA, const arma::mat &distB, const arma:
arma::mat lamp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys);
void lamp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys, arma::mat &Y);
+arma::mat plmp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys);
+void plmp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys, arma::mat &Y);
+
+//arma::mat lsp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys, int k = 15);
+//void lsp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys, arma::mat &Y, int k = 15);
+
arma::mat forceScheme(const arma::mat &D, arma::mat &Y, size_t maxIter = 20, double tol = 1e-3, double fraction = 8);
arma::mat tSNE(const arma::mat &X, arma::uword k = 2, double perplexity = 30, arma::uword nIter = 1000);
diff --git a/plmp.cpp b/plmp.cpp
new file mode 100644
index 0000000..0ca1b83
--- /dev/null
+++ b/plmp.cpp
@@ -0,0 +1,22 @@
+
+#include "mp.h"
+
+arma::mat mp::plmp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys)
+{
+ arma::mat Y(X.n_rows, Ys.n_cols);
+ mp::plmp(X, sampleIndices, Ys, Y);
+ return Y;
+}
+
+void mp::plmp(const arma::mat &X, const arma::uvec &sampleIndices, const arma::mat &Ys, arma::mat &Y)
+{
+ arma::mat Xs = X.rows(sampleIndices);
+ Xs.each_row() -= arma::mean(Xs);
+ arma::mat lYs = Ys;
+ lYs.each_row() -= arma::mean(Ys);
+ const arma::mat &Xst = Xs.t();
+ arma::mat P = arma::solve(Xst * Xs, Xst * lYs);
+
+ Y = X * P;
+ Y.rows(sampleIndices) = lYs;
+}
diff --git a/pm.pro b/pm.pro
index 281cb55..88a4b61 100644
--- a/pm.pro
+++ b/pm.pro
@@ -28,6 +28,8 @@ SOURCES += main.cpp \
distortionobserver.cpp \
npdistortion.cpp \
lamp.cpp \
+ plmp.cpp \
+ knn.cpp \
forceScheme.cpp \
tsne.cpp \
measures.cpp \