forked from aaalgo/picpac
-
Notifications
You must be signed in to change notification settings - Fork 0
/
picpac.h
709 lines (653 loc) · 22.7 KB
/
picpac.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
#pragma once
#include <array>
#include <vector>
#include <string>
#include <mutex>
#include <condition_variable>
#include <thread>
#include <algorithm>
#include <limits>
#include <stdexcept>
#include <random>
#include <functional>
#include <boost/filesystem.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/asio/buffer.hpp>
#include <glog/logging.h>
namespace picpac {
using std::array;
using std::vector;
using std::string;
using std::numeric_limits;
using std::runtime_error;
using boost::lexical_cast;
using boost::asio::const_buffer;
typedef std::lock_guard<std::mutex> lock_guard;
typedef std::unique_lock<std::mutex> unique_lock;
namespace fs = boost::filesystem;
/// Static coded maximal number of fields per record
static constexpr unsigned MAX_FIELDS = 6;
/// Static coded segment header size
static constexpr unsigned SEGMENT_HEADER_SIZE = 8192;
/// Static coded maximal number of records per segment
/**
 * This must be carefully calculated so that the segment
 * header struct adds up to SEGMENT_HEADER_SIZE.
 */
static constexpr unsigned MAX_SEG_RECORDS = 1020;
/// Record alignment for faster access
static constexpr unsigned RECORD_ALIGN = 4096;
/// Static coded maximal record size
static constexpr size_t MAX_RECORD_SIZE = 512*1024*1024; // 512MB
// Record sizes are stored in 32-bit fields; keep the limit within int32 range.
static_assert(MAX_RECORD_SIZE < numeric_limits<int32_t>::max(), "record too large");
/// Maximal number of categories.
static constexpr unsigned MAX_CATEGORIES = 2000;
/* Maximal category ID is (MAX_CATEGORIES - 1)
 * we need to make sure this can be stored in float without
 * loss of precision. If it's too big, the LSB 1 will be lost.
 */
static constexpr unsigned MAX_CATEGORY_TEST = (MAX_CATEGORIES - 1) | 1;
static_assert(float(MAX_CATEGORY_TEST) == MAX_CATEGORY_TEST, "too many categories");
/// Default seed of the pseudo random number generator.
static constexpr int DEFAULT_SEED = 2016;
/// Default prefetch queue length.
static constexpr unsigned DEFAULT_PRELOAD = 256;
/// Default number of worker threads (used when a config asks for 0).
static constexpr unsigned DEFAULT_THREADS = 4;
// format stack trace
// format stack trace
/// Captured call stack for error reporting; each element is a symbol name.
/** NOTE(review): constructor/destructor/format are defined in the
 * implementation file; the `symbols` buffer is presumably owned by this
 * object and freed by the destructor -- confirm against the .cpp.
 */
class Stack: public std::vector<char const *> {
static const unsigned MAX_BACKTRACE = 100;
char **symbols;
public:
Stack ();
~Stack ();
/// Render the stack, prepending `prefix` to the output.
string format (std::string const &prefix = "") const;
};
/// Tags stored in Meta::Field::type to describe a field's content.
enum FieldType { // Record field type
FIELD_NONE = 0,
/*
FIELD_FILE = 1,
FIELD_TEXT = 2,
FIELD_OTHER = 3
*/
// Sentinel: one past the largest defined type; used only by the size check below.
CHECK_FIELD_SIZE
};
// Field type is persisted in a uint8_t (Meta::Field::type); enforce that it fits.
static_assert(CHECK_FIELD_SIZE - 1 <= numeric_limits<uint8_t>::max(), "Too many field types");
/// Thrown for an unusable record label; the message is the label value.
class BadLabel: public runtime_error {
public:
BadLabel (int l): runtime_error(lexical_cast<string>(l)) {}
};
/// Thrown for an unusable file; the message is the native path.
class BadFile: public runtime_error {
public:
BadFile (fs::path const &p): runtime_error(p.native()) {}
};
/// Thrown when on-disk data fails integrity checks.
class DataCorruption: public runtime_error {
public:
DataCorruption (): runtime_error("picpac data corruption") {}
};
/// Thrown when a record exceeds MAX_RECORD_SIZE; the message is the size.
class BadRecordSize: public runtime_error {
public:
BadRecordSize (uintmax_t sz): runtime_error(lexical_cast<string>(sz)) {}
};
/// Meta data of a record
/** Packed on-disk layout; the struct is exactly 64 bytes (asserted below),
 * so fields must not be reordered or resized.
 */
struct __attribute__((__packed__)) Meta {
struct __attribute__((__packed__)) Field { // 8 bytes
/// Field size
uint32_t size;
/// Field type. See FieldType.
uint8_t type;
uint8_t reserved1;
uint16_t reserved2;
};
// total 16 bytes
/// For storing user ID. PicPac does not use the ID field.
uint32_t id;
/// Label of record -- if it is relevant.
/** Label can be an integer representing category ID,
* or a float number for regression. In the previous case,
* number of category must not exceed system limitation.
*/
float label;
/// Number of fields in the record.
uint8_t width;
uint8_t reserved1;
int16_t label2; // optional secondary label, for stratification
uint32_t reserved2;
/// Meta data of fields.
std::array<Field, MAX_FIELDS> fields;
/// Zero the entire struct in place (safe: the type is trivially copyable).
void init () {
char *begin = reinterpret_cast<char *>(this);
std::fill(begin, begin + sizeof(*this), 0);
}
// only copies user info
// (id/labels/reserved); width and fields describe storage and are left alone.
void copy (Meta const &m) {
id = m.id;
label = m.label;
reserved1 = m.reserved1;
label2 = m.label2;
reserved2 = m.reserved2;
}
};
static_assert(sizeof(Meta) == 64, "bad Meta size");
/// Data Record, non-copiable but movable.
/** A Record owns its storage: all field data are kept in raw/on-disk
 * format inside one contiguous buffer (`data`); meta_ptr and field_ptrs
 * are interior pointers into that buffer and must be rebuilt whenever
 * it is reallocated (see alloc_helper and swap).
 */
class Record { // record owns the data
    // All field data are stored in raw/on-disk format in the data field
    string data; // raw data
    // Interior pointers used to access `data`.  Default-initialized to
    // null so a default-constructed Record is safe to move/swap; they
    // were previously left uninitialized, which made moving from or
    // querying an empty Record undefined behavior.
    Meta *meta_ptr = nullptr;               // pointer into data
    array<char *, MAX_FIELDS> field_ptrs{}; // pointers into data
    /// Terminal case: `off` is the total record size; allocate the buffer
    /// and zero-initialize the Meta header.
    void alloc_helper (int nf, uintmax_t off) {
        if (!(off < MAX_RECORD_SIZE)) {
            throw BadRecordSize(off);
        }
        data.resize(off);
        meta_ptr = reinterpret_cast<Meta *>(&data[0]);
        meta_ptr->init();
        meta_ptr->width = nf;
    }
    /// Peel one field: reserve `size` bytes at offset `off` for field `ifld`.
    template <typename ... Args>
    void alloc_helper (int ifld, uintmax_t off, uintmax_t size, Args... args) {
        // Recurse first so `data` is fully sized (and meta_ptr valid)
        // before the field pointer is recorded.
        alloc_helper(ifld + 1, off + size, args...);
        meta_ptr->fields[ifld].size = size;
        field_ptrs[ifld] = &data[off];
    }
    /// Allocate storage for fields of the given sizes and set the label.
    template <typename ... Args>
    void alloc (double label, Args... args) {
        alloc_helper(0, sizeof(Meta), args...);
        meta_ptr->label = label;
    }
public:
    Record (const Record&) = delete;
    Record& operator=(const Record&) = delete;
    /// Exchange contents with another record; interior pointers travel
    /// with the buffers they point into, so both records stay consistent.
    void swap (Record &r) noexcept {
        data.swap(r.data);
        std::swap(meta_ptr, r.meta_ptr);
        std::swap(field_ptrs, r.field_ptrs);
    }
    Record (Record &&r) noexcept {
        swap(r);
    }
    Record& operator=(Record &&r) noexcept {
        swap(r);
        return *this;
    }
    // Serialization to/from a file descriptor; implemented elsewhere.
    ssize_t write (int fd, bool compact) const;
    ssize_t read (int fd, off_t off, size_t size);
    /// Construct an empty record, for future read from disk.
    Record () {}
    /// Construct a record with file content.
    Record (float label, fs::path const &file);
    /// Construct a record with file content and an extra string.
    Record (float label, fs::path const &file, string const &extra);
    /// Construct a record with the content of two files.
    Record (float label, fs::path const &file, fs::path const &file2);
    /// Construct a record with a string as content.
    Record (float label, string const &data);
    /// Construct a record with string content and an extra string.
    Record (float label, string const &data, string const &extra);
    Meta &meta () { return *meta_ptr; }
    Meta const &meta () const { return *meta_ptr; }
    /// Return number of fields; must not be called on an empty record.
    unsigned size () const { return meta_ptr->width; }
    // replace existing field to buf
    // if type >= 0, type value is also replaced
    // because a record is compactly stored, replacing a field
    // needs to reallocate the storage and copy the existing data over
    void replace (unsigned f, string const &buf, int type = -1);
    /// Get field buffer.
    const_buffer field (unsigned f) const {
        CHECK(f < meta_ptr->width);
        return const_buffer(field_ptrs[f], meta_ptr->fields[f].size);
    }
    /// Copy field `f` into a string; returns empty string if out of range.
    string field_string (unsigned f) const {
        if (f < meta_ptr->width) {
            return string(field_ptrs[f], field_ptrs[f] + meta_ptr->fields[f].size);
        }
        return string();
    }
    /// Get field type.
    FieldType fieldType (unsigned f) const {
        CHECK(f < meta_ptr->width);
        return FieldType(meta_ptr->fields[f].type);
    }
};
/// On-disk header of one segment; packed, exactly 8192 bytes (asserted below).
struct __attribute__((__packed__)) SegmentHeader {
static uint16_t constexpr MAGIC = 0x59AC;
static uint16_t constexpr VERSION = 0x0100; // major 01, minor 00
uint16_t magic;
uint16_t version;
uint16_t size; // number of records
uint16_t reserved1;
uint64_t link; // next segment offset
// -- 16 bytes so far
uint64_t reserved2;
uint64_t reserved3;
// -- 8160 bytes below
/// Per-record byte sizes within this segment.
array<uint32_t, MAX_SEG_RECORDS> sizes;
/// Per-record group values within this segment.
array<float, MAX_SEG_RECORDS> groups;
/// Zero the struct and stamp the magic/version fields.
void init () {
char *begin = reinterpret_cast<char *>(this);
std::fill(begin, begin + sizeof(*this), 0);
magic = MAGIC;
version = VERSION;
}
};
static_assert(sizeof(SegmentHeader) == 8192, "Bad segment header size");
static_assert(sizeof(SegmentHeader) % RECORD_ALIGN == 0, "Bad segment header size");
/// Appends records to a PicPac file, segment by segment.
/** open_segment()/close_segment()/append() and the constructor/destructor
 * are defined in the implementation file; `seg` and `seg_off` track the
 * segment currently being filled.
 */
class FileWriter {
int fd;
int flags;
off_t seg_off; // last segment offset
SegmentHeader seg; // segment header
unsigned next; // next record offset within segment
// initialize a new segment at the end of file
// and set status to the new segment
void open_segment ();
// write the meta data of last segment to file
void close_segment ();
/// Whether the COMPACT flag was passed (presumably disables the
/// RECORD_ALIGN padding -- confirm against the implementation).
bool compact () const {
return !!(flags & COMPACT);
}
public:
enum {
INDEX_LABEL2 = 1, // use "label" by default
// this makes segment header store label2
COMPACT = 2
};
FileWriter (fs::path const &path, int flags_ = 0);
~FileWriter ();
void append (Record const &r);
};
/// Position and identity of one record within a (possibly multi-file) stream.
struct Locator {
/// Byte offset of the record within its file.
off_t offset;
/// Record size in bytes.
uint32_t size;
/// Group value used for stratification (see SegmentHeader::groups).
float group;
/// Sequential index of the record; used e.g. to address per-record caches.
uint32_t serial;
// A stream might have multiple underlying files
// and these are distinguished with file
uint32_t file;
};
/// A callback function that reads the record.
typedef std::function<void(Record *)> RecordReader;
/// Read access to one PicPac file via a file descriptor.
/** The descriptor is presumably opened by the constructor and closed by
 * the destructor (both defined elsewhere) -- confirm against the .cpp.
 */
class FileReader {
int fd;
public:
FileReader (fs::path const &path);
~FileReader ();
/// Append a Locator for every record to *l, tagging each with the given
/// file index (implemented elsewhere).
void ping (vector<Locator> *l, uint32_t file = 0);
/// Read the record at l; CHECK-fails unless exactly l.size bytes are read.
void read (Locator const &l, Record *r) {
ssize_t sz = r->read(fd, l.offset, l.size);
CHECK(sz == l.size);
}
/*
RecordReader reader (Locator l) {
return [this, l](Record *r){read(l, r);};
}
*/
};
/// FileReader plus an in-memory index built once at construction,
/// enabling random access to records by sequential position.
class IndexedFileReader: public FileReader {
    vector<Locator> index;
public:
    IndexedFileReader (fs::path const &path)
        : FileReader(path) {
        ping(&index);   // scan the file once to build the index
    }
    /// Number of records in the file.
    size_t size () const { return index.size(); }
    /// Group value of record i; throws std::out_of_range on bad index.
    float group (size_t i) const {
        if (!(i < index.size())) throw std::out_of_range("IndexedFileReader::group: bad record index");
        return index[i].group;
    }
    /// Read record i into *r; throws std::out_of_range on bad index.
    void read (size_t i, Record *r) {
        if (!(i < index.size())) throw std::out_of_range("IndexedFileReader::read: bad record index");
        FileReader::read(index[i], r);
    }
    /// Locator of record i; now bounds-checked like the other accessors
    /// (previously returned a reference into the vector unchecked).
    Locator const &locator (size_t i) const {
        if (!(i < index.size())) throw std::out_of_range("IndexedFileReader::locator: bad record index");
        return index[i];
    }
    /// Invoke cb on every locator, in file order.
    void loopIndex (std::function<void(Locator const &)> cb) {
        for (auto const &l: index) {
            cb(l);
        }
    }
    /// Read every record in file order and invoke cb on it.
    void loop (std::function<void(Record &)> cb) {
        Record rec;
        for (auto const &l: index) {
            FileReader::read(l, &rec);
            cb(rec);
        }
    }
};
/// End of Stream exception, thrown when no more data is loaded
/** Empty tag type used for control flow (see PrefetchStream, which
 * catches it around Stream::next()). */
struct EoS {
};
/// Record stream over one or more PicPac files, with optional shuffling,
/// stratified grouping and train/validation splitting.
/** Construction, destruction and next() are implemented elsewhere;
 *  only state accessors and reset() are inline here.
 */
class Stream: public FileReader {
public:
    struct Config {
        int seed;                       // random seed
        bool loop;                      // restart at end of data
        bool shuffle;
        bool reshuffle;
        int stratify;                   // flag; initialized from a bool below
        unsigned split;                 // number of splits (K)
        vector<unsigned> split_keys;    // explicit split selection
        int split_fold;
        bool split_negate;
        string mixin;                   // extra data source to mix in
        float mixin_group_delta;
        unsigned mixin_max;
        Config()
            : seed(DEFAULT_SEED),
              loop(true),
              shuffle(true),
              reshuffle(true),
              stratify(true),
              split(1),
              split_fold(0),
              split_negate(false),
              mixin_group_delta(0),
              mixin_max(0) {
        }
        /// Initialize split scheme for K-fold cross validation.
        /**
         * if train:
         *      use K-1 splits other than fold
         *
         * if not train:
         *      use 1 split specified by fold
         */
        void kfold ();
    };
protected:
    Config config;
    std::default_random_engine rng;
private:
    vector<FileReader *> readers;
    struct Group {
        unsigned id;            // unique group ID
        vector<Locator> index;  // records belonging to this group
        unsigned next;          // next record to serve within the group
    };
    vector<Group> groups;
    vector<unsigned> group_index;
    unsigned next_group;
    unsigned sz_total;          // total records in the file(s)
    unsigned sz_used;           // records remaining after splitting
    unsigned ncat;
    unsigned ngroup;
public:
    Stream (fs::path const &, Config const &);
    ~Stream ();
    /// Number of categories found in the data.
    unsigned categories () const {
        return ncat;
    }
    /// Rewind every group and restart streaming from the beginning.
    void reset () {
        group_index.clear();
        for (unsigned i = 0; i < groups.size(); ++i) {
            groups[i].next = 0;
            group_index.push_back(i);
        }
        next_group = 0;
    }
    /// Locate the next record to serve (implemented elsewhere).
    Locator next ();
    /// Bind a locator to its underlying file reader as a callback.
    RecordReader reader (Locator l) {
        return [this, l](Record *r){this->readers[l.file]->read(l, r);};
    }
    /// Read the next record in stream order into *r.
    void read_next (Record *r) {
        Locator l = next();
        // l.file is unsigned, so the former "l.file >= 0 &&" test was a
        // tautology; only the upper bound needs checking.
        CHECK(l.file < readers.size());
        readers[l.file]->read(l, r);
    }
    // return total records in the file
    unsigned total () const {
        return sz_total;
    }
    // return number of records actually used (after splitting)
    unsigned size () const {
        return sz_used;
    }
};
/// Dummy loader that directly returns the record itself.
/** All loaders must follow exactly the same interface as DummyLoader:
* - define Config, Value and PerturbVector.
* - implement sample and load
*/
class DummyLoader {
public:
struct Config {
};
typedef Record Value;
struct CacheValue {
};
struct PerturbVector {
};
DummyLoader (Config const &) {}
/// Sample a perturb vector.
/** This is guaranteed to be run in serial. */
template <typename RNG>
void sample (RNG &, PerturbVector *) {
}
/// Convert a record into value.
/** This might be invoked in parallel and should be deterministic.
* All randomization should be done in sample. */
void load (RecordReader rr, PerturbVector const &, Value *out,
CacheValue *, std::mutex *) const {
Record r;
rr(&r);
*out = std::move(r);
}
};
/// Stream with prefetch and transformation.
/**
* This stream does parallel prefetching and transformation.
* To plugin in a transformation, parameterize this class with a
* Loader class. This stream preserves the order of underlying stream
* for reproducibility. All randomization are done in serial.
*/
template <typename Loader = DummyLoader> // transform class to serve as base class
class PrefetchStream: public Stream, public Loader {
public:
typedef typename Loader::Value Value;
typedef typename Loader::CacheValue CacheValue;
typedef typename Loader::PerturbVector PerturbVector;
private:
/// One slot of the circular prefetch queue.
struct Task { // prefetch task
enum Status {
EMPTY = 0,
PENDING,
LOADING,
LOADED
} status;
// Task state transform graph:
// empty -> pending -> loading -> loaded -> empty
Locator locator;
PerturbVector perturb;
Value value;
Task (): status(EMPTY) {
}
};
bool started;
unsigned nth; // # threads
bool eos; // eos signal from upstream
int inqueue; // pending + loaded
// About peek & prefetch:
// The basic work flow is
// - add a prefetch task
// - consume a record
// So that # tasks in queue is always the same.
// When peeking is involved, a prefetch task might be
// added without consuming a record. So the follow becomes:
//
// * Peak
// - if not previous peeked, add a prefetch task
// - return a record without consume it
// * next:
// - if not previous peeked, add a prefetch task
// - remove the peek status
// - consume a record.
// When multiple peeks are done consecutively, they see
// the same record.
bool peeked; // peek has been invoked
vector<Task> queue; // prefetch queue
vector<CacheValue> cache;
// Ring-buffer cursors into `queue`; each advances modulo queue.size().
unsigned next_loaded;
unsigned next_pending;
unsigned next_empty;
std::condition_variable has_pending;
std::condition_variable has_loaded;
std::mutex mutex;
std::mutex cache_mutex;
vector<std::thread> threads;
// Holds the value returned by next(); valid until the following call.
Value value_holder;
// Enqueue one prefetch task into the next EMPTY slot.
// "_unsafe": caller must hold `mutex` (or, as in start(), no workers
// may be running yet).  Returns false once upstream EoS was hit.
bool prefetch_unsafe () {
if (eos) return false;
try {
Task &task = queue[next_empty];
CHECK(task.status == Task::EMPTY);
task.locator = Stream::next();
// Perturbation is sampled here, serially, for reproducibility.
Loader::sample(rng, &task.perturb);
task.status = Task::PENDING;
next_empty = (next_empty + 1) % queue.size();
++inqueue;
has_pending.notify_one();
return true;
}
catch (EoS) {
eos = true;
// Wake every worker so all of them can observe eos and exit.
has_pending.notify_all();
return false;
}
}
// Worker thread body: claim the next PENDING task under the lock,
// then run the (possibly slow) load outside the lock.
void worker () {
for (;;) {
unsigned todo = 0;
{
unique_lock lock(mutex);
while (queue[next_pending].status != Task::PENDING) {
if (eos) return;
has_pending.wait(lock);
}
todo = next_pending;
next_pending = (next_pending + 1) % queue.size();
}
// NOTE(review): the status writes below happen outside `mutex`;
// visibility to the consumer appears to rely on the later
// notify/wait pairing on the same mutex -- confirm.
Task &task = queue[todo];
task.status = Task::LOADING;
CacheValue *pc = nullptr;
std::mutex *pm = nullptr;
if (cache.size() > 0) {
CHECK(task.locator.serial < cache.size());
pc = &cache[task.locator.serial];
pm = &cache_mutex;
}
try {
Loader::load(reader(task.locator), task.perturb, &task.value, pc, pm);
}
catch (runtime_error const &e) {
LOG(ERROR) << "runtime_error: " << e.what();
LOG(ERROR) << Stack().format();
throw;
}
catch (...) {
LOG(ERROR) << "unknown_error";
LOG(ERROR) << Stack().format();
throw;
}
task.status = Task::LOADED;
// add memory barrier
has_loaded.notify_one();
}
}
// Reset the queue, pre-fill it with prefetch tasks and launch workers.
void start () {
CHECK(!started);
started = true;
eos = false;
inqueue = 0;
peeked = false;
next_loaded = next_pending = next_empty = 0;
CHECK(queue.size());
for (auto &v: queue) {
v.status = Task::EMPTY;
}
for (unsigned i = 0; i < queue.size() - 1; ++i) {
// next always enqueues 1 prefetch task before it takes away an item
// need to leave one space for that
if (!prefetch_unsafe()) break;
}
//LOG(INFO) << "Starting " << nth << " threads.";
CHECK(threads.empty());
for (unsigned i = 0; i < nth; ++i) {
threads.emplace_back([this](){this->worker();});
}
}
// Signal end-of-stream and join all worker threads.
void stop () {
CHECK(started);
started = false;
eos = true;
has_pending.notify_all();
for (auto &th: threads) {
th.join();
}
threads.clear();
}
public:
/// Combined stream + loader configuration.
struct Config: public Stream::Config, public Loader::Config {
bool cache; // keep per-record CacheValue slots (passed to Loader::load)
unsigned preload; // prefetch queue length
unsigned threads; // 0 means DEFAULT_THREADS (see constructor)
Config (): cache(true),
preload(DEFAULT_PRELOAD),
threads(0)
{
}
};
/// Open the stream at `p` and immediately start prefetching.
PrefetchStream (fs::path const &p, Config const &c)
: Stream(p, c), Loader(c), started(false),
nth(c.threads > 0? c.threads: DEFAULT_THREADS), queue(c.preload+1) {
if (c.cache) {
// One cache slot per record in the file, addressed by Locator::serial.
cache.resize(total());
}
// enqueue tasks
start();
}
~PrefetchStream () {
stop();
}
/// Restart the underlying stream; stops and relaunches the workers.
void reset () {
stop();
Stream::reset();
start();
}
/// Return a reference to the next value without consuming it.
/// Consecutive peeks see the same record.  Throws EoS when drained.
Value &peek () {
unique_lock lock(mutex);
if (!peeked) {
prefetch_unsafe();
peeked = true;
}
Task &next = queue[next_loaded];
while (next.status != Task::LOADED) {
if (inqueue == 0) {
throw EoS();
}
has_loaded.wait(lock);
}
return next.value;
}
/// Consume and return the next value.  The returned reference is only
/// valid until the following call to next().  Throws EoS when drained.
Value &&next () {
unique_lock lock(mutex);
if (peeked) {
peeked = false;
}
else {
prefetch_unsafe();
}
Task &next = queue[next_loaded];
while (next.status != Task::LOADED) {
if (inqueue == 0) {
throw EoS();
}
has_loaded.wait(lock);
}
value_holder = std::move(next.value);
next.status = Task::EMPTY;
--inqueue;
next_loaded = (next_loaded + 1) % queue.size();
return std::move(value_holder);
}
};
}