aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--main.cpp52
1 files changed, 30 insertions, 22 deletions
diff --git a/main.cpp b/main.cpp
index febeea2..45fef63 100644
--- a/main.cpp
+++ b/main.cpp
@@ -28,6 +28,12 @@ static QObject *mainProvider(QQmlEngine *engine, QJSEngine *scriptEngine)
return Main::instance();
}
+arma::uvec extractCPs(const arma::mat &X)
+{
+ int numCPs = (int) (3 * sqrt(X.n_rows));
+ return arma::randi<arma::uvec>(numCPs, arma::distr_param(0, X.n_rows - 1));
+}
+
int main(int argc, char **argv)
{
QApplication app(argc, argv);
@@ -61,36 +67,38 @@ int main(int argc, char **argv)
arma::mat X = m->X();
arma::vec labels = m->labels();
- arma::uword n = X.n_rows;
- int cpSize;
arma::uvec cpIndices;
arma::mat Ys;
+ QString indicesFilename;
if (parser.isSet(indicesFileOutputOption)) {
- const QString &indicesFilename = parser.value(indicesFileOutputOption);
+ indicesFilename = parser.value(indicesFileOutputOption);
+ }
+ QFile indicesFile(indicesFilename);
+ if (indicesFile.exists()) {
m->setIndicesSavePath(indicesFilename);
- QFile indicesFile(indicesFilename);
- if (indicesFile.exists()) {
- cpIndices.load(indicesFilename.toStdString(), arma::raw_ascii);
- cpSize = cpIndices.n_elem;
- } else {
- cpSize = (int) (3 * sqrt(n));
- // cpIndices = relevanceSampling(X, cpSize);
- cpIndices = arma::randi<arma::uvec>(cpSize, arma::distr_param(0, n-1));
- }
-
- arma::sort(cpIndices);
+ cpIndices.load(indicesFilename.toStdString(), arma::raw_ascii);
+ } else {
+ cpIndices = extractCPs(X);
}
+
+ arma::sort(cpIndices);
+
+ QString cpFilename;
if (parser.isSet(cpFileOutputOption)) {
- const QString &cpFilename = parser.value(cpFileOutputOption);
+ cpFilename = parser.value(cpFileOutputOption);
+ }
+ QFile cpFile(cpFilename);
+ if (cpFile.exists()) {
m->setCPSavePath(cpFilename);
- QFile cpFile(cpFilename);
- if (cpFile.exists()) {
- Ys.load(cpFilename.toStdString(), arma::raw_ascii);
- } else {
- Ys.set_size(cpSize, 2);
- mp::forceScheme(mp::dist(X.rows(cpIndices)), Ys);
- }
+ Ys.load(cpFilename.toStdString(), arma::raw_ascii);
+ } else {
+ Ys.set_size(cpIndices.n_elem, 2);
+ mp::forceScheme(mp::dist(X.rows(cpIndices)), Ys);
+ }
+ if (cpIndices.n_elem != Ys.n_rows) {
+ std::cerr << "The number of CP indices and the CP map do not match." << std::endl;
+ return 1;
}
m->setCPIndices(cpIndices);