-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTrtBuffer.h
81 lines (55 loc) · 1.93 KB
/
TrtBuffer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#ifndef __TRT_BUFFER_H__
#define __TRT_BUFFER_H__
#include "utils.h"
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include <cuda_runtime.h>
#include "spdlog/spdlog.h"
#include "Int8EntropyCalibrator.h"
#include "concurrentqueue/blockingconcurrentqueue.h"
#include <string>
#include <vector>
#include <iostream>
#include <algorithm>
/// Groups a TensorRT engine and execution context with the CUDA stream and the
/// per-binding buffers/metadata that accompany them.
///
/// Only declarations are visible in this header; the constructors, destructor
/// and most methods are defined elsewhere, so the notes below describe the
/// declared interface rather than the implementation.
///
/// NOTE(review): the class declares a destructor and holds raw pointers and a
/// stream handle, yet copy construction/assignment are left implicit. Confirm
/// that instances are never copied (or delete/define the copy operations).
/// NOTE(review): the single-argument constructor is not `explicit`, so a
/// std::string converts implicitly to TrtBuffer — confirm that is intended.
class TrtBuffer {
public:
    /// Creates an empty object: null engine, null context, null stream and
    /// empty binding vectors.
    TrtBuffer() = default;
    ~TrtBuffer();

    /// Constructs from an engine file path (presumably a serialized engine —
    /// see the definition to confirm what is loaded and initialized).
    TrtBuffer(const std::string &engineFile);

    /// Constructs from an engine file path plus an optimization-profile index
    /// and maximum dimensions (presumably for dynamic-shape engines — confirm
    /// against the definition).
    TrtBuffer(const std::string &engineFile, int profileIndex, nvinfer1::Dims maxDims);

    /// Deserializes the engine stored in `engineFile`.
    /// @return presumably true on success — confirm against the definition.
    bool DeserializeEngine(const std::string &engineFile);

    // Model initialization.
    void InitEngine();

    // Model initialization for dynamic shapes.
    void InitEngine(int profileIndex, nvinfer1::Dims maxDims);

    /// Asynchronous transfer for binding `bindIndex`; `isHostToDevice` selects
    /// the copy direction. NOTE(review): the unit of `size` (bytes vs.
    /// elements) is not visible here — confirm against the definition.
    void DataTransferAsync(int size, int bindIndex, bool isHostToDevice);

    void ForwardAsync(); // fixed dimensions
    void ForwardAsync(nvinfer1::Dims &dim);
    void GetOutput();

    /// Returns volume() applied to the runtime dimensions of `bindIndex`.
    /// volume() is declared outside this header (presumably utils.h).
    /// Requires `context` to be non-null.
    size_t GetRuntimeBindingSize(int bindIndex) const {
        return volume(GetRuntimeBindingDims(bindIndex));
    }

    /// Returns the dimensions the execution context currently reports for
    /// `bindIndex`. Dereferences `context`, which defaults to nullptr, so the
    /// context must have been created before this is called.
    nvinfer1::Dims GetRuntimeBindingDims(int bindIndex) const {
        return context->getBindingDimensions(bindIndex);
    }

    /// Returns the data type recorded in `bindingDataType` for `bindIndex`.
    /// @throws std::out_of_range if `bindIndex` is negative or not smaller
    ///         than bindingDataType.size() (previously undefined behaviour).
    nvinfer1::DataType GetBindingDataType(int bindIndex) const {
        // A negative int converts to a huge size_t, which at() rejects too.
        return bindingDataType.at(static_cast<std::size_t>(bindIndex));
    }

    /// Blocks until the work queued on `stream` has finished.
    /// NOTE(review): the status returned by cudaStreamSynchronize is discarded.
    void StreamSynchronize() const {
        cudaStreamSynchronize(stream);
    }

    // Public state.
    // Initialized to null so a default-constructed object never exposes an
    // indeterminate stream handle to StreamSynchronize() or the destructor.
    cudaStream_t stream = nullptr;
    nvinfer1::IExecutionContext *context = nullptr;
    std::vector<void *> bindingDevice;   // device memory
    std::vector<float *> bindingHost;    // page-locked (pinned) host memory
    std::vector<int> inputBindIndex;     // input binding indices
    std::vector<int> outputBindIndex;    // output binding indices
    std::vector<int> bindingSize;        // maximum size per binding (presumably inputs and outputs — confirm)
    std::vector<nvinfer1::Dims> bindingDims;
    std::vector<nvinfer1::DataType> bindingDataType;

private:
    TrtLogger mLogger;
    nvinfer1::ICudaEngine *_engine = nullptr;
};
#endif // !__TRT_BUFFER_H__