Skip to content

Commit

Permalink
Added options to specify dimension ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
nightduck committed Jan 12, 2021
1 parent f6884f0 commit 04a5823
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
34 changes: 31 additions & 3 deletions include/tkDNN/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

namespace tk { namespace dnn {

/**
    Ordering of the entries in a nvinfer1::Dims object when it is
    converted to a dataDim_t.
*/
enum dimFormat_t {
    CHW  = 0,   // no batch dimension: entries are [c, h, w, l]
    NCHW = 1,   // batch dimension first: entries are [n, c, h, w, l]
    //NHWC      // reserved: channels-last ordering, not yet supported
};

/**
Data representation between layers
n = batch size
Expand All @@ -21,9 +27,31 @@ struct dataDim_t {

dataDim_t() : n(1), c(1), h(1), w(1), l(1) {};

// Build from a TensorRT Dims object assuming CHW ordering (batch size
// fixed to 1); any zero extent defaults to 1.
// NOTE(review): reads d.d[0..3] without checking d.nbDims — entries past
// nbDims are not guaranteed to be initialized by TensorRT. Confirm all
// callers pass fully-populated Dims.
dataDim_t(nvinfer1::Dims &d) :
n(1), c(d.d[0] ? d.d[0] : 1), h(d.d[1] ? d.d[1] : 1),
w(d.d[2] ? d.d[2] : 1), l(d.d[3] ? d.d[3] : 1) {};
/**
    Build a dataDim_t from a TensorRT Dims object.
    @param d  binding dimensions as returned by TensorRT. Entries of d.d
              past d.nbDims are not guaranteed to be initialized, so they
              must never be read.
    @param df ordering of the entries in d: CHW has no batch dimension
              (n is forced to 1), NCHW carries the batch dimension first.
    Any missing or zero extent defaults to 1.
*/
dataDim_t(nvinfer1::Dims &d, dimFormat_t df) {
    // Safe accessor: returns 1 when the index is past d.nbDims (the old
    // code read d.d[] unconditionally, picking up undefined values for
    // e.g. l when nbDims < 5) or when the stored extent is 0.
    auto dim = [&d](int i) { return (i < d.nbDims && d.d[i]) ? d.d[i] : 1; };

    switch(df) {
    case CHW:           // no explicit batch dimension
        n = 1;
        c = dim(0);
        h = dim(1);
        w = dim(2);
        l = dim(3);
        break;
    case NCHW:          // batch dimension first
        n = dim(0);
        c = dim(1);
        h = dim(2);
        w = dim(3);
        l = dim(4);
        break;
    // case NHWC: channels-last, not yet supported
    default:            // unknown format: degrade to a single element
        n = c = h = w = l = 1;
        break;
    }
}

// Build from explicit extents; the extra dimension l defaults to 1.
dataDim_t(int _n, int _c, int _h, int _w, int _l = 1) :
n(_n), c(_c), h(_h), w(_w), l(_l) {};
Expand Down
2 changes: 1 addition & 1 deletion include/tkDNN/NetworkRT.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class NetworkRT {

PluginFactory *pluginFactory;

NetworkRT(Network *net, const char *name, const char *input_name="data", const char *output_name="out");
NetworkRT(Network *net, const char *name, dimFormat_t dim_format=CHW, const char *input_name="data", const char *output_name="out");
virtual ~NetworkRT();

int getMaxBatchSize() {
Expand Down
8 changes: 4 additions & 4 deletions src/NetworkRT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace tk { namespace dnn {

std::map<Layer*, nvinfer1::ITensor*>tensors;

NetworkRT::NetworkRT(Network *net, const char *name, const char *input_name, const char *output_name) {
NetworkRT::NetworkRT(Network *net, const char *name, dimFormat_t dim_format, const char *input_name, const char *output_name) {

float rt_ver = float(NV_TENSORRT_MAJOR) +
float(NV_TENSORRT_MINOR)/10 +
Expand Down Expand Up @@ -167,17 +167,17 @@ NetworkRT::NetworkRT(Network *net, const char *name, const char *input_name, con


Dims iDim = engineRT->getBindingDimensions(buf_input_idx);
input_dim = dataDim_t(iDim);
input_dim = dataDim_t(iDim, dim_format);
input_dim.print();

Dims oDim = engineRT->getBindingDimensions(buf_output_idx);
output_dim = dataDim_t(oDim);
output_dim = dataDim_t(oDim, dim_format);
output_dim.print();

// create GPU buffers and a stream
for(int i=0; i<engineRT->getNbBindings(); i++) {
Dims dim = engineRT->getBindingDimensions(i);
buffersDIM[i] = dataDim_t(dim);
buffersDIM[i] = dataDim_t(dim, dim_format);
std::cout<<"RtBuffer "<<i<<" dim: "; buffersDIM[i].print();
checkCuda(cudaMalloc(&buffersRT[i], engineRT->getMaxBatchSize()*buffersDIM[i].tot()*sizeof(dnnType)));
}
Expand Down

0 comments on commit 04a5823

Please sign in to comment.