forked from aaalgo/picpac
-
Notifications
You must be signed in to change notification settings - Fork 0
/
python-api.cpp
175 lines (158 loc) · 5.99 KB
/
python-api.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#include <fstream>
#include <boost/ref.hpp>
#include <boost/python.hpp>
#include <boost/python/make_constructor.hpp>
#include <boost/python/raw_function.hpp>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/ndarrayobject.h>
#include "picpac.h"
#include "picpac-cv.h"
using namespace boost::python;
using namespace picpac;
namespace {
// Returns the data buffer of the NumPy array held by `o`, viewed as `T *`.
//
// Nothing is validated here: the object is reinterpreted as a PyArrayObject
// and its data pointer is reinterpreted as T*, so the caller is responsible
// for passing an ndarray whose element type matches T.
template <typename T>
T *get_ndarray_data (object &o) {
    return reinterpret_cast<T *>(
            PyArray_DATA(reinterpret_cast<PyArrayObject *>(o.ptr())));
}
class NumpyBatchImageStream: public BatchImageStream {
public:
NumpyBatchImageStream (std::string const &path, Config const &c)
: BatchImageStream(fs::path(path), c) {
import_array();
}
tuple next () {
vector<npy_intp> images_dims;
vector<npy_intp> labels_dims;
next_shape(&images_dims, &labels_dims);
object images = object(boost::python::handle<>(PyArray_SimpleNew(images_dims.size(), &images_dims[0], NPY_FLOAT)));
CHECK(images.ptr());
float *images_buf = get_ndarray_data<float>(images);
object labels = object(boost::python::handle<>(PyArray_SimpleNew(labels_dims.size(), &labels_dims[0], NPY_FLOAT)));
CHECK(labels.ptr());
float *labels_buf = get_ndarray_data<float>(labels);
unsigned padding;
next_fill(images_buf, labels_buf, &padding);
return make_tuple(images, labels, padding);
}
};
// Python-visible constructor of ImageStream: ImageStream(path, **kwargs).
//
// args[0] is the object under construction and args[1] the database path.
// A default Config is created and each field handled by
// PICPAC_CONFIG_UPDATE_ALL is overwritten by the keyword argument of the same
// name when one is supplied (the macro is defined outside this file;
// presumably it applies PICPAC_CONFIG_UPDATE once per config field).
// Construction is then forwarded to the (string, Config) __init__ overload.
//
// Raises TypeError when the path argument is missing or is not a string.
object create_image_stream (tuple args, dict kwargs) {
    object self = args[0];
    if (len(args) < 2) {
        // Report misuse as a Python exception instead of aborting the whole
        // interpreter through CHECK.
        PyErr_SetString(PyExc_TypeError,
                        "ImageStream() requires a path argument");
        throw_error_already_set();
    }
    string path = extract<string>(args[1]);
    NumpyBatchImageStream::Config config;
    /* Disabled: former train/K/fold keyword handling.
    bool train = extract<bool>(kwargs.get("train", true));
    unsigned K = extract<unsigned>(kwargs.get("K", 1));
    unsigned fold = extract<unsigned>(kwargs.get("fold", 0));
    if (K <= 1) {
        if (!train) {
            config.loop = false;
            config.reshuffle = false;
        }
    }
    else {
        config.kfold(K, fold, train);
    }
    */
    // For field P: take kwargs[P] if present, otherwise keep the default.
#define PICPAC_CONFIG_UPDATE(C, P) \
    C.P = extract<decltype(C.P)>(kwargs.get(#P, C.P))
    PICPAC_CONFIG_UPDATE_ALL(config);
#undef PICPAC_CONFIG_UPDATE
    return self.attr("__init__")(path, config);
}
// Implements __iter__ for ImageStream.
//
// Calls the stream's own "reset" method and then returns the stream itself,
// so the stream object serves as its own iterator. Keyword arguments are
// accepted (raw_function signature) but not used.
object return_iterator (tuple args, dict kwargs) {
    object stream = args[0];
    stream.attr("reset")();
    return stream;
}
class Writer: public FileWriter {
public:
Writer (string const &path): FileWriter(fs::path(path), FileWriter::COMPACT) {
}
void append (float label, string const &buf) {
Record record(label, buf);
FileWriter::append(record);
}
void append (string const &buf1, string const &buf2) {
Record record(0, buf1, buf2);
FileWriter::append(record);
}
};
// Writes a 2-D (rows x cols) or 3-D (rows x cols x channels) NumPy array to
// `os` in a raw, OpenCV-style layout.
//
// Output format, all integers in native byte order:
//   int type      OpenCV type code, CV_MAKETYPE(depth, channels)
//   int rows
//   int cols
//   int elemSize  bytes per pixel (item size * channels)
//   rows * cols * elemSize bytes of pixel data, row by row
//
// Requirements, enforced with CHECK:
//   - 2 or 3 dimensions, native (or not-applicable) byte order;
//   - element type is one of uint8/int8/uint16/int16/int32/float32/float64;
//   - between 1 and CV_CN_MAX channels;
//   - the array is C-contiguous.
//
// `obj` is assumed to hold an ndarray; this is not verified here.
void serialize_raw_ndarray (object &obj, std::ostream &os) {
    PyArrayObject *image = reinterpret_cast<PyArrayObject *>(obj.ptr());
    int const nd = PyArray_NDIM(image);
    CHECK(nd == 2 || nd == 3);
    auto desc = PyArray_DESCR(image);
    CHECK(desc);
    CHECK(PyArray_EquivByteorders(desc->byteorder, NPY_NATIVE)
          || desc->byteorder == '|') << "Only support native/little endian";
    // Checking only the row stride is not enough: a view such as
    // a.swapaxes(1, 2) has the same row stride as a packed image but a
    // different in-row element order, and would be written scrambled.
    CHECK(PyArray_IS_C_CONTIGUOUS(image)) << "Image must be C-contiguous";

    int elemSize = desc->elsize;
    CHECK(elemSize > 0) << "Flex type not supported.";

    // CV_MAKETYPE packs the channel count into a fixed number of bits, so an
    // out-of-range count would silently corrupt the type code.
    npy_intp const channels = (nd == 2) ? 1 : PyArray_DIM(image, 2);
    CHECK(channels >= 1 && channels <= CV_CN_MAX)
        << "bad channel count: " << channels;
    int const ch = static_cast<int>(channels);

    int const t = PyArray_TYPE(image);
    int type = 0;
    switch (t) {
        case NPY_UINT8:   type = CV_MAKETYPE(CV_8U, ch); break;
        case NPY_INT8:    type = CV_MAKETYPE(CV_8S, ch); break;
        case NPY_UINT16:  type = CV_MAKETYPE(CV_16U, ch); break;
        case NPY_INT16:   type = CV_MAKETYPE(CV_16S, ch); break;
        case NPY_INT32:   type = CV_MAKETYPE(CV_32S, ch); break;
        case NPY_FLOAT32: type = CV_MAKETYPE(CV_32F, ch); break;
        case NPY_FLOAT64: type = CV_MAKETYPE(CV_64F, ch); break;
        default: CHECK(0) << "type not supported: " << t;
    }
    elemSize *= ch; // opencv elements includes all channels

    // The header stores the dimensions as int; refuse anything that would
    // be truncated by the conversion from npy_intp.
    npy_intp const dim_rows = PyArray_DIM(image, 0);
    npy_intp const dim_cols = PyArray_DIM(image, 1);
    int const rows = static_cast<int>(dim_rows);
    int const cols = static_cast<int>(dim_cols);
    CHECK(rows == dim_rows && cols == dim_cols) << "image too large";

    os.write(reinterpret_cast<char const *>(&type), sizeof(type));
    os.write(reinterpret_cast<char const *>(&rows), sizeof(rows));
    os.write(reinterpret_cast<char const *>(&cols), sizeof(cols));
    os.write(reinterpret_cast<char const *>(&elemSize), sizeof(elemSize));

    // The buffer is contiguous, so all rows can be emitted with one write.
    // Sizes are computed in streamsize to avoid int overflow on big images.
    std::streamsize const row_bytes =
        static_cast<std::streamsize>(cols) * elemSize;
    os.write(PyArray_BYTES(image), row_bytes * rows);
}
// Serialises `obj` with serialize_raw_ndarray into an in-memory stream and
// returns the accumulated bytes as a string.
string encode_raw_ndarray (object &obj) {
    std::ostringstream buffer;
    serialize_raw_ndarray(obj, buffer);
    return buffer.str();
}
// Serialises `obj` with serialize_raw_ndarray into the file at `path`,
// opened in binary mode (an existing file is overwritten).
//
// Raises IOError in Python when the file cannot be opened or when writing
// fails; previously both conditions were silently ignored and the caller
// could be left with a missing or truncated file.
void write_raw_ndarray (string const &path, object &obj) {
    std::ofstream os(path.c_str(), std::ios::binary);
    if (!os) {
        string const msg = "cannot open file for writing: " + path;
        PyErr_SetString(PyExc_IOError, msg.c_str());
        throw_error_already_set();
    }
    serialize_raw_ndarray(obj, os);
    // Flush explicitly so that buffered write errors are detected here
    // rather than being lost in the destructor.
    os.flush();
    if (!os) {
        string const msg = "failed to write file: " + path;
        PyErr_SetString(PyExc_IOError, msg.c_str());
        throw_error_already_set();
    }
}
void (Writer::*append1) (float, string const &) = &Writer::append;
void (Writer::*append2) (string const &, string const &) = &Writer::append;
// Exception translator for EoS: sets Python's StopIteration (with no
// value) as the pending exception via the Python C API.
void translate_eos (EoS const &) {
    PyErr_SetNone(PyExc_StopIteration);
}
}
// Definition of the `_picpac` Python extension module.
//
// Exposes:
//   ImageStreamParams  - NumpyBatchImageStream::Config, default-constructible
//   ImageStream        - batch iterator; ImageStream(path, **config_fields)
//   Writer             - record writer with two `append` overloads
//   encode_raw(array)        -> serialised bytes
//   write_raw(path, array)   -> writes the serialised array to a file
// and registers a translator that turns the C++ EoS exception into
// StopIteration.
BOOST_PYTHON_MODULE(_picpac)
{
    // Load the NumPy C-API table at import time. The import_array() macro is
    // avoided because its embedded `return` statement does not fit a void
    // function under Python 3; _import_array() is the function it wraps and
    // leaves a Python exception set on failure.
    if (_import_array() < 0) {
        throw_error_already_set();
    }
    scope().attr("__doc__") = "PicPac Python API";
    register_exception_translator<EoS>(&translate_eos);
    class_<NumpyBatchImageStream::Config>("ImageStreamParams", init<>());
    // Boost.Python tries overloads in reverse order of registration, so the
    // registration order of the two __init__ entries below must be kept.
    class_<NumpyBatchImageStream, boost::noncopyable>("ImageStream", no_init)
        .def("__init__", raw_function(create_image_stream), "exposed ctor")
        .def("__iter__", raw_function(return_iterator))
        // C++ constructor (path, Config); create_image_stream forwards to it.
        .def(init<string, NumpyBatchImageStream::Config const&>())
        .def("next", &NumpyBatchImageStream::next)      // Python 2 iterator protocol
        .def("__next__", &NumpyBatchImageStream::next)  // Python 3 iterator protocol
        .def("size", &NumpyBatchImageStream::size)
        .def("reset", &NumpyBatchImageStream::reset)
        .def("categories", &NumpyBatchImageStream::categories)
    ;
    class_<Writer>("Writer", init<string>())
        .def("append", append1)
        .def("append", append2)
    ;
    def("encode_raw", ::encode_raw_ndarray);
    def("write_raw", ::write_raw_ndarray);
}