-
Notifications
You must be signed in to change notification settings - Fork 19
/
picpac-kfold.cpp
96 lines (86 loc) · 2.74 KB
/
picpac-kfold.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <iostream>
#include <boost/accumulators/accumulators.hpp>
#include <boost/accumulators/statistics/stats.hpp>
#include <boost/accumulators/statistics/mean.hpp>
#include <boost/accumulators/statistics/min.hpp>
#include <boost/accumulators/statistics/max.hpp>
#include <boost/accumulators/statistics/variance.hpp>
#include <boost/filesystem/fstream.hpp>
#include <boost/program_options.hpp>
#include "picpac-cv.h"
using namespace std;
using namespace picpac;
namespace ba = boost::accumulators;
int main(int argc, char const* argv[]) {
fs::path input_path;
fs::path output_path;
Stream::Config config;
unsigned max_test;
namespace po = boost::program_options;
po::options_description desc("Allowed options");
desc.add_options()
("help,h", "produce help message.")
("input", po::value(&input_path), "")
("output", po::value(&output_path), "")
("seed", po::value(&config.seed), "")
("split", po::value(&config.split)->default_value(5), "")
("fold", po::value(&config.split_fold)->default_value(0), "")
("stratify", po::value(&config.stratify), "")
("max-test", po::value(&max_test)->default_value(0), "0 for no limit")
;
po::positional_options_description p;
p.add("input", 1);
p.add("output", 1);
po::variables_map vm;
po::store(po::command_line_parser(argc, argv).
options(desc).positional(p).run(), vm);
po::notify(vm);
if (vm.count("help") || input_path.empty() || output_path.empty()) {
cout << "Usage:" << endl;
cout << "\tpicpac-split <output> <input> [<input> ...]" << endl;
cout << desc;
cout << endl;
return 0;
}
fs::path train_path(output_path);
train_path += ".train";
fs::path test_path(output_path);
test_path += ".test";
FileWriter train(train_path, FileWriter::COMPACT);
FileWriter test(test_path, FileWriter::COMPACT);
config.loop = false;
config.shuffle = true;
{
config.split_negate = false;
Stream str(input_path, config);
for (;;) {
Record rec;
try {
str.read_next(&rec);
}
catch (EoS const &) {
break;
}
train.append(rec);
}
}
{
config.split_negate = true;
Stream str(input_path, config);
for (unsigned cc = 0;; ++cc) {
Record rec;
try {
str.read_next(&rec);
}
catch (EoS const &) {
break;
}
if ((max_test == 0) || (cc < max_test)) {
test.append(rec);
}
else {
train.append(rec);
}
}
}
return 0;
}