update to tensorrt 10.3
CoinCheung committed Sep 6, 2024
1 parent f2b9015 commit 81b1179
Showing 7 changed files with 281 additions and 240 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -6,8 +6,8 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV
mIOUs and fps on cityscapes val set:
| none | ss | ssc | msf | mscf | fps(fp32/fp16/int8) | link |
|------|:--:|:---:|:---:|:----:|:---:|:----:|
| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 25/78/141 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 26/67/95 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |
| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 112/239/435 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 103/161/198 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |

mIOUs on cocostuff val2017 set:
| none | ss | ssc | msf | mscf | link |
tensorrt/CMakeLists.txt (6 changes: 3 additions & 3 deletions)
@@ -2,8 +2,8 @@ CMAKE_MINIMUM_REQUIRED(VERSION 3.17)

PROJECT(segment)

set(CMAKE_CXX_FLAGS "-std=c++14 -O2")
set(CMAKE_NVCC_FLAGS "-std=c++14 -O2")
set(CMAKE_CXX_FLAGS "-std=c++17 -O2")
set(CMAKE_NVCC_FLAGS "-std=c++20 -O2")


link_directories(/usr/local/cuda/lib64)
@@ -21,7 +21,7 @@ add_executable(segment segment.cpp trt_dep.cpp read_img.cpp)
target_include_directories(
segment PUBLIC ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS})
target_link_libraries(
segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser -lkernels
segment -lnvinfer -lnvinfer_plugin -lnvonnxparser -lkernels
${CUDA_LIBRARIES}
${OpenCV_LIBRARIES})
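
Dropping `-lnvparsers` from the link line is expected here: TensorRT 10 removes the legacy Caffe/UFF parsers, so only the ONNX parser library (`nvonnxparser`) remains. As a rough sketch of what that implies for the parsing code (illustrative, not taken from this commit):

```cpp
// Illustrative only: in TensorRT 10 the legacy Caffe/UFF parsers (libnvparsers)
// are gone, so an ONNX model is read through nvonnxparser alone.
#include <NvInfer.h>
#include <NvOnnxParser.h>

nvonnxparser::IParser* parse_onnx_model(nvinfer1::INetworkDefinition& network,
                                        nvinfer1::ILogger& logger,
                                        const char* onnx_path) {
    nvonnxparser::IParser* parser = nvonnxparser::createParser(network, logger);
    // Keep the parser alive until the engine is built: it owns the parsed weights.
    bool ok = parser->parseFromFile(
        onnx_path, static_cast<int>(nvinfer1::ILogger::Severity::kWARNING));
    if (!ok) {
        delete parser;
        return nullptr;
    }
    return parser;
}
```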

tensorrt/README.md (35 changes: 15 additions & 20 deletions)
@@ -17,12 +17,12 @@ Then we can use either c++ or python to compile the model and run inference.

#### 1. My platform

* ubuntu 18.04
* nvidia Tesla T4 gpu, driver newer than 450.80
* cuda 11.3, cudnn 8
* cmake 3.22.0
* ubuntu 22.04
* nvidia A40 gpu, driver newer than 555.42.06
* cuda 12.1, cudnn 8
* cmake 3.22.1
* opencv built from source
* tensorrt 8.2.5.1
* tensorrt 10.3.0.26



@@ -39,14 +39,14 @@ This would generate a `./segment` in the `tensorrt/build` directory.

#### 3. Convert onnx to tensorrt model
If you can successfully compile the source code, you can convert the onnx model into a tensorrt engine with one of the following commands.
For fp32, the command is:
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt
```
If your gpu supports acceleration with fp16 inference, you can add a `--fp16` option in this step:
For fp32/fp16/bf16, the command is:
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp32
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp16
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --bf16
```
Make sure that your gpu supports acceleration with fp16/bf16 inference when you set these options.<br>
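
For orientation, these precision switches typically map onto TensorRT 10 builder-config flags roughly as in the sketch below; this is an illustration of the usual pattern, not the exact code in `trt_dep.cpp`:

```cpp
// Rough mapping of the CLI precision options to TensorRT builder-config flags.
// Illustrative sketch only; the actual logic lives in trt_dep.cpp and may differ.
#include <string>
#include <NvInfer.h>

void set_precision(nvinfer1::IBuilderConfig& config, const std::string& quant) {
    if (quant == "fp16") {
        config.setFlag(nvinfer1::BuilderFlag::kFP16);
    } else if (quant == "bf16") {
        config.setFlag(nvinfer1::BuilderFlag::kBF16);
    } else if (quant == "int8") {
        // int8 additionally requires a calibrator fed with real images (see the int8 section below).
        config.setFlag(nvinfer1::BuilderFlag::kINT8);
    }
    // fp32 is the default: no extra flag is needed.
}
```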

Building an int8 engine is also supported. First, make sure your gpu supports int8 inference, or your model will not be faster than fp16/fp32. Then prepare a certain amount of images for int8 calibration. In this example, I use the train set of cityscapes for calibration. The command is like this:
```
$ rm calibrate_int8 # delete this if exists
```

@@ -72,26 +72,21 @@ $ ./segment test /path/to/saved_model.trt


#### 6. Tips:
1. ~Since tensorrt 7.0.0 cannot parse the `bilinear interpolation` op exported from pytorch well, I replace it with pytorch `nn.PixelShuffle`, which brings some performance overhead (more flops and parameters) and makes inference a bit slower. Also, due to the `nn.PixelShuffle` op, you **must** export the onnx model with an input size that is a multiple of 32.~
If you are using 7.2.3.4 or newer versions, you should not have problems with `interpolate` anymore.

2. ~There would be some problems for tensorrt 7.0.0 parsing the `nn.AvgPool2d` op from pytorch with onnx opset11, so I use opset10 to export the model.~
Likewise, you do not need to worry about this anymore with versions newer than 7.2.3.4.
The speed(fps) is tested on a single nvidia A40 gpu with `batchsize=1` and `cropsize=(1024,2048)`, which might be different from your platform and settings, so you should evaluate the speed considering your own platform and cropsize. Also note that the performance would be affected if your gpu is concurrently working on other tasks. Please make sure no other program is running on your gpu when you test the speed.

3. The speed(fps) is tested on a single nvidia Tesla T4 gpu with `batchsize=1` and `cropsize=(1024,2048)`. Please note that the T4 gpu is almost 2 times slower than a 2080ti; you should evaluate the speed considering your own platform and cropsize. Also note that the performance would be affected if your gpu is concurrently working on other tasks. Please make sure no other program is running on your gpu when you test the speed.

4. On my platform, after compiling with tensorrt, the model size of bisenetv1 is 29Mb(fp16) and 128Mb(fp32), and the size of bisenetv2 is 16Mb(fp16) and 42Mb(fp32). However, the fps of bisenetv1 is 68(fp16) and 23(fp32), while the fps of bisenetv2 is 59(fp16) and 21(fp32). Bisenetv2 clearly has fewer parameters than bisenetv1, yet it runs slower. I am not sure whether this is because tensorrt has a worse optimization strategy for some ops used in bisenetv2 (such as depthwise convolution) or because of gpu limitations on different ops. Please tell me if you have a better idea on this.

5. int8 mode is not always greatly faster than fp16 mode. For example, I tested with bisenetv1-cityscapes and tensorrt 8.2.5.1. With v100 gpu and driver 515.65, the fp16/int8 fps is 185.89/186.85, while with t4 gpu and driver 450.80, it is 78.77/142.31.
### Using python (this is not updated to tensorrt 10.3)

You can also use the python script to compile and run inference of your model.<br>

### Using python

You can also use the python script to compile and run inference of your model.
The following still describes the usage for tensorrt 8.2.<br>


#### 1. Compile model to onnx


With this command:
```
$ cd BiSeNet/tensorrt
```
tensorrt/batch_stream.hpp (1 change: 0 additions & 1 deletion)
@@ -52,7 +52,6 @@ class BatchStream : public IBatchStream

void reset(int firstBatch) override
{
cout << "mBatchCount: " << mBatchCount << endl;
mBatchCount = firstBatch;
}

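`batch_stream.hpp` supplies host-side calibration batches for int8 engine building. For context, an int8 entropy calibrator that consumes such a batch stream usually looks roughly like the sketch below; the stream's method names follow the TensorRT sample `IBatchStream` convention and are assumptions, not the repository's exact code:

```cpp
// Sketch of an int8 entropy calibrator fed by a BatchStream-like object.
// Assumed shape for illustration; the repository's actual calibrator may differ.
#include <cuda_runtime.h>
#include <NvInfer.h>
#include <cstddef>
#include <cstdint>

template <typename StreamT>
class EntropyCalibrator : public nvinfer1::IInt8EntropyCalibrator2 {
public:
    EntropyCalibrator(StreamT& stream, std::size_t inputBytes)
        : mStream(stream), mInputBytes(inputBytes) {
        cudaMalloc(&mDeviceInput, mInputBytes);
    }
    ~EntropyCalibrator() override { cudaFree(mDeviceInput); }

    int32_t getBatchSize() const noexcept override { return mStream.getBatchSize(); }

    bool getBatch(void* bindings[], char const* names[], int32_t nbBindings) noexcept override {
        if (!mStream.next()) return false;  // calibration data exhausted
        // Copy the current host batch into the device buffer TensorRT reads from.
        cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputBytes, cudaMemcpyHostToDevice);
        bindings[0] = mDeviceInput;         // single input tensor assumed
        return true;
    }

    // A calibration cache lets later builds skip recalibration; omitted in this sketch.
    void const* readCalibrationCache(std::size_t& length) noexcept override { length = 0; return nullptr; }
    void writeCalibrationCache(void const* cache, std::size_t length) noexcept override {}

private:
    StreamT& mStream;
    std::size_t mInputBytes;
    void* mDeviceInput{nullptr};
};
```
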
tensorrt/segment.cpp (99 changes: 54 additions & 45 deletions)
@@ -13,6 +13,7 @@
#include <array>
#include <sstream>
#include <random>
#include <unordered_map>

#include "trt_dep.hpp"
#include "read_img.hpp"
@@ -27,8 +28,7 @@ using nvinfer1::IBuilderConfig;
using nvinfer1::IRuntime;
using nvinfer1::IExecutionContext;
using nvinfer1::ILogger;
using nvinfer1::Dims3;
using nvinfer1::Dims2;
using nvinfer1::Dims;
using Severity = nvinfer1::ILogger::Severity;

using std::string;
@@ -39,6 +39,7 @@ using std::vector;
using std::cout;
using std::endl;
using std::array;
using std::stringstream;

using cv::Mat;

@@ -53,81 +54,84 @@ void test_speed(vector<string> args);


int main(int argc, char* argv[]) {
if (argc < 3) {
cout << "usage is ./segment compile/run/test\n";
std::abort();
}
CHECK (argc >= 3, "usage is ./segment compile/run/test");

vector<string> args;
for (int i{1}; i < argc; ++i) args.emplace_back(argv[i]);

if (args[0] == "compile") {
if (argc < 4) {
cout << "usage is: ./segment compile input.onnx output.trt [--fp16|--fp32]\n";
cout << "or ./segment compile input.onnx output.trt --int8 /path/to/data_root /path/to/ann_file\n";
std::abort();
}
stringstream ss;
ss << "usage is: ./segment compile input.onnx output.trt [--fp16|--fp32|--bf16|--fp8]\n"
<< "or ./segment compile input.onnx output.trt --int8 /path/to/data_root /path/to/ann_file\n";
CHECK (argc >= 5, ss.str());
compile_onnx(args);
} else if (args[0] == "run") {
if (argc < 5) {
cout << "usage is ./segment run ./xxx.trt input.jpg result.jpg\n";
std::abort();
}
CHECK (argc >= 5, "usage is ./segment run ./xxx.trt input.jpg result.jpg");
run_with_trt(args);
} else if (args[0] == "test") {
if (argc < 3) {
cout << "usage is ./segment test ./xxx.trt\n";
std::abort();
}
CHECK (argc >= 3, "usage is ./segment test ./xxx.trt");
test_speed(args);
} else {
CHECK (false, "usage is ./segment compile/run/test");
}

return 0;
}
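
The rewritten `main` leans on a `CHECK(condition, message)` helper that this diff does not show (it presumably comes from `trt_dep.hpp`). A macro with the assumed behavior, checking a condition and aborting with a message otherwise, would look something like this sketch:

```cpp
// Assumed shape of the CHECK helper used above: report the message and abort
// when the condition fails. Not the actual definition from trt_dep.hpp.
#include <cstdlib>
#include <iostream>

#define CHECK(cond, msg)                        \
    do {                                        \
        if (!(cond)) {                          \
            std::cout << (msg) << std::endl;    \
            std::abort();                       \
        }                                       \
    } while (0)
```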


void compile_onnx(vector<string> args) {

string quant("fp32");
string data_root("none");
string data_file("none");
if ((args.size() >= 4)) {
if (args[3] == "--fp32") {
quant = "fp32";
} else if (args[3] == "--fp16") {
quant = "fp16";
} else if (args[3] == "--int8") {
quant = "int8";
data_root = args[4];
data_file = args[5];
} else {
cout << "invalid args of quantization: " << args[3] << endl;
std::abort();
}
}
int opt_bsize = 1;

std::unordered_map<string, string> quant_map{
{"--fp32", "fp32"},
{"--fp16", "fp16"},
{"--bf16", "bf16"},
{"--fp8", "fp8"},
{"--int8", "int8"},
};
CHECK (quant_map.find(args[3]) != quant_map.end(),
"invalid args of quantization: " + args[3]);
quant = quant_map[args[3]];
if (quant == "int8") {
data_root = args[4];
data_file = args[5];
}

if (args[3] == "--int8") {
if (args.size() > 6) opt_bsize = std::stoi(args[6]);
} else {
if (args.size() > 4) opt_bsize = std::stoi(args[4]);
}

TrtSharedEnginePtr engine = parse_to_engine(args[1], quant, data_root, data_file);
serialize(engine, args[2]);
SemanticSegmentTrt ss_trt;
ss_trt.set_opt_batch_size(opt_bsize);
ss_trt.parse_to_engine(args[1], quant, data_root, data_file);
ss_trt.serialize(args[2]);
}
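
Note that `compile_onnx` now also reads an optional trailing batch-size argument (`args[4]` for fp32/fp16/bf16/fp8, `args[6]` for int8) and passes it to `set_opt_batch_size`, so a call like `./segment compile model.onnx model.trt --fp16 8` would presumably build the engine with an optimization-profile batch size of 8.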


void run_with_trt(vector<string> args) {

TrtSharedEnginePtr engine = deserialize(args[1]);
SemanticSegmentTrt ss_trt;
ss_trt.deserialize(args[1]);

Dims3 i_dims = static_cast<Dims3&&>(
engine->getBindingDimensions(engine->getBindingIndex("input_image")));
Dims3 o_dims = static_cast<Dims3&&>(
engine->getBindingDimensions(engine->getBindingIndex("preds")));
const int iH{i_dims.d[2]}, iW{i_dims.d[3]};
const int oH{o_dims.d[2]}, oW{o_dims.d[3]};
vector<int> i_dims = ss_trt.get_input_shape();
vector<int> o_dims = ss_trt.get_output_shape();

const int iH{i_dims[2]}, iW{i_dims[3]};
const int oH{o_dims[2]}, oW{o_dims[3]};

// prepare image and resize
vector<float> data; data.resize(iH * iW * 3);
int orgH, orgW;
read_data(args[2], &data[0], iH, iW, orgH, orgW);

// call engine
vector<int> res = infer_with_engine(engine, data);
vector<int> res = ss_trt.inference(data);

// generate colored out
vector<vector<uint8_t>> color_map = get_color_map();
@@ -166,6 +170,11 @@ vector<vector<uint8_t>> get_color_map() {


void test_speed(vector<string> args) {
TrtSharedEnginePtr engine = deserialize(args[1]);
test_fps_with_engine(engine);
int opt_bsize = 1;
if (args.size() > 2) opt_bsize = std::stoi(args[2]);

SemanticSegmentTrt ss_trt;
ss_trt.set_opt_batch_size(opt_bsize);
ss_trt.deserialize(args[1]);
ss_trt.test_speed_fps();
}
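
Taken together, the calls above show that the free functions from the previous version (`parse_to_engine`, `serialize`, `deserialize`, `infer_with_engine`, `test_fps_with_engine`) are now wrapped in a `SemanticSegmentTrt` class from `trt_dep.hpp`. The interface implied by its usage in `segment.cpp` is roughly the following reconstruction; the actual declaration may differ in details such as const-qualification and argument types:

```cpp
// Interface implied by the calls in segment.cpp; a reconstruction for reference,
// not a verbatim copy of trt_dep.hpp.
#include <string>
#include <vector>

class SemanticSegmentTrt {
public:
    // Optimization-profile batch size used when building/running the engine.
    void set_opt_batch_size(int bsize);

    // Build an engine from an ONNX model; quant is "fp32"/"fp16"/"bf16"/"fp8"/"int8",
    // and the data paths are only used for int8 calibration.
    void parse_to_engine(const std::string& onnx_path, const std::string& quant,
                         const std::string& data_root, const std::string& data_file);

    // Save/load the serialized engine.
    void serialize(const std::string& save_path);
    void deserialize(const std::string& engine_path);

    // NCHW shapes of the input and output tensors ("input_image" and "preds").
    std::vector<int> get_input_shape();
    std::vector<int> get_output_shape();

    // Run inference on a preprocessed float image and return per-pixel class ids.
    std::vector<int> inference(std::vector<float>& data);

    // Benchmark throughput of the loaded engine.
    void test_speed_fps();
};
```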
