From 2355b010723d5479f71f8c02dac752fbbd03ebfa Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Wed, 20 Jan 2016 10:32:19 +0100 Subject: main(): better code for loading CP data. --- main.cpp | 52 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 30 insertions(+), 22 deletions(-) (limited to 'main.cpp') diff --git a/main.cpp b/main.cpp index febeea2..45fef63 100644 --- a/main.cpp +++ b/main.cpp @@ -28,6 +28,12 @@ static QObject *mainProvider(QQmlEngine *engine, QJSEngine *scriptEngine) return Main::instance(); } +arma::uvec extractCPs(const arma::mat &X) +{ + int numCPs = (int) (3 * sqrt(X.n_rows)); + return arma::randi(numCPs, arma::distr_param(0, X.n_rows - 1)); +} + int main(int argc, char **argv) { QApplication app(argc, argv); @@ -61,36 +67,38 @@ int main(int argc, char **argv) arma::mat X = m->X(); arma::vec labels = m->labels(); - arma::uword n = X.n_rows; - int cpSize; arma::uvec cpIndices; arma::mat Ys; + QString indicesFilename; if (parser.isSet(indicesFileOutputOption)) { - const QString &indicesFilename = parser.value(indicesFileOutputOption); + indicesFilename = parser.value(indicesFileOutputOption); + } + QFile indicesFile(indicesFilename); + if (indicesFile.exists()) { m->setIndicesSavePath(indicesFilename); - QFile indicesFile(indicesFilename); - if (indicesFile.exists()) { - cpIndices.load(indicesFilename.toStdString(), arma::raw_ascii); - cpSize = cpIndices.n_elem; - } else { - cpSize = (int) (3 * sqrt(n)); - // cpIndices = relevanceSampling(X, cpSize); - cpIndices = arma::randi(cpSize, arma::distr_param(0, n-1)); - } - - arma::sort(cpIndices); + cpIndices.load(indicesFilename.toStdString(), arma::raw_ascii); + } else { + cpIndices = extractCPs(X); } + + arma::sort(cpIndices); + + QString cpFilename; if (parser.isSet(cpFileOutputOption)) { - const QString &cpFilename = parser.value(cpFileOutputOption); + cpFilename = parser.value(cpFileOutputOption); + } + QFile cpFile(cpFilename); + if (cpFile.exists()) { m->setCPSavePath(cpFilename); - QFile cpFile(cpFilename); - if (cpFile.exists()) { - Ys.load(cpFilename.toStdString(), arma::raw_ascii); - } else { - Ys.set_size(cpSize, 2); - mp::forceScheme(mp::dist(X.rows(cpIndices)), Ys); - } + Ys.load(cpFilename.toStdString(), arma::raw_ascii); + } else { + Ys.set_size(cpIndices.n_elem, 2); + mp::forceScheme(mp::dist(X.rows(cpIndices)), Ys); + } + if (cpIndices.n_elem != Ys.n_rows) { + std::cerr << "The number of CP indices and the CP map do not match." << std::endl; + return 1; } m->setCPIndices(cpIndices); -- cgit v1.2.3