Skip to content

Commit

Permalink
add trt python demo
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed Jul 4, 2022
1 parent 4dcf170 commit c2d90c4
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 19 deletions.
74 changes: 55 additions & 19 deletions tensorrt/README.md
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@

### My platform

* ubuntu 18.04
* nvidia Tesla T4 gpu, driver newer than 440
* cuda 10.2, cudnn 8
* cmake 3.10.2
* opencv built from source
* tensorrt 7.2.3.4
## Deploy with Tensorrt


### Export model to onnx
I export the model like this:
Firstly, We should export our trained model to onnx model:
```
$ cd BiSeNet/
$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx
```

**NOTE:** I use cropsize of `1024x2048` here in my example, you should change it according to your specific application. The inference cropsize is fixed from this step on, so you should decide the inference cropsize when you export the model here.
**NOTE:** I use cropsize of `1024x2048` here in my example, you should change it according to your specific application. The inference cropsize is fixed from this step on, so you should decide the inference cropsize when you export the model here.

Then we can use either c++ or python to compile the model and run inference.


### Using C++

#### My platform

### Build with source code
* ubuntu 18.04
* nvidia Tesla T4 gpu, driver newer than 450.80
* cuda 11.3, cudnn 8
* cmake 3.17.1
* opencv built from source
* tensorrt 8.2.5.1



#### Build with source code
Just use the standard cmake build method:
```
mkdir -p tensorrt/build
Expand All @@ -28,7 +37,7 @@ make
This would generate a `./segment` in the `tensorrt/build` directory.


### Convert onnx to tensorrt model
#### Convert onnx to tensorrt model
If you can successfully compile the source code, you can parse the onnx model to tensorrt model like this:
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt
Expand All @@ -37,30 +46,57 @@ If your gpu support acceleration with fp16 inferenece, you can add a `--fp16` op
```
$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp16
```
Note that I use the simplest method to parse the command line args, so please do **Not** change the order of the above command.
Note that I use the simplest method to parse the command line args, so please do **Not** change the order of the args in above command.


### Infer with one single image
#### Infer with one single image
Run inference like this:
```
$ ./segment run /path/to/saved_model.trt /path/to/input/image.jpg /path/to/saved_img.jpg
```

### Test speed

#### Test speed
The speed depends on the specific gpu platform you are working on, you can test the fps on your gpu like this:
```
$ ./segment test /path/to/saved_model.trt
```


## Tips:
#### Tips:
1. ~Since tensorrt 7.0.0 cannot parse well the `bilinear interpolation` op exported from pytorch, I replace them with pytorch `nn.PixelShuffle`, which would bring some performance overhead(more flops and parameters), and make inference a bit slower. Also due to the `nn.PixelShuffle` op, you **must** export the onnx model with input size to be *n* times of 32.~
If you are using 7.2.3.4, you should not have problem with `interpolate` anymore.
If you are using 7.2.3.4 or newer versions, you should not have problem with `interpolate` anymore.

2. ~There would be some problem for tensorrt 7.0.0 to parse the `nn.AvgPool2d` op from pytorch with onnx opset11. So I use opset10 to export the model.~
Likewise, you do not need to worry about this anymore with 7.2.3.4.
Likewise, you do not need to worry about this anymore with version newer than 7.2.3.4.

3. The speed(fps) is tested on a single nvidia Tesla T4 gpu with `batchsize=1` and `cropsize=(1024,2048)`. Please note that T4 gpu is almost 2 times slower than 2080ti, you should evaluate the speed considering your own platform and cropsize. Also note that the performance would be affected if your gpu is concurrently working on other tasks. Please make sure no other program is running on your gpu when you test the speed.

4. On my platform, after compiling with tensorrt, the model size of bisenetv1 is 29Mb(fp16) and 128Mb(fp32), and the size of bisenetv2 is 16Mb(fp16) and 42Mb(fp32). However, the fps of bisenetv1 is 68(fp16) and 23(fp32), while the fps of bisenetv2 is 59(fp16) and 21(fp32). It is obvious that bisenetv2 has fewer parameters than bisenetv1, but the speed is otherwise. I am not sure whether it is because tensorrt has worse optimization strategy in some ops used in bisenetv2(such as depthwise convolution) or because of the limitation of the gpu on different ops. Please tell me if you have better idea on this.


### Using python

You can also use python script to compile and run inference of your model.


#### Compile model to onnx

With this command:
```
$ cd BiSeNet/tensorrt
$ python segment.py compile --onnx /path/to/model.onnx --savepth ./model.trt --quant fp16/fp32
```

This will compile onnx model into tensorrt serialized engine, save save to `./model.trt`.


#### inference with Tensorrt

Run Inference like this:
```
$ python segment.py run --mdpth ./model.trt --impth ../example.png --outpth ./res.png
```

This will use the tensorrt model compiled above, and run inference with the example image.

157 changes: 157 additions & 0 deletions tensorrt/segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@

import os
import os.path as osp
import cv2
import numpy as np
import logging
import argparse

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit


parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")
compile_parser = subparsers.add_parser('compile')
compile_parser.add_argument('--onnx')
compile_parser.add_argument('--quant', default='fp32')
compile_parser.add_argument('--savepth', default='./model.trt')
run_parser = subparsers.add_parser('run')
run_parser.add_argument('--mdpth')
run_parser.add_argument('--impth')
run_parser.add_argument('--outpth', default='./res.png')
args = parser.parse_args()


np.random.seed(123)
in_datatype = trt.nptype(trt.float32)
out_datatype = trt.nptype(trt.int32)
palette = np.random.randint(0, 256, (256, 3)).astype(np.uint8)

ctx = pycuda.autoinit.context
trt.init_libnvinfer_plugins(None, "")
TRT_LOGGER = trt.Logger()



def get_image(impth, size):
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None]
var = np.array([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None]
iH, iW = size[0], size[1]
img = cv2.imread(impth)[:, :, ::-1]
orgH, orgW, _ = img.shape
img = cv2.resize(img, (iW, iH)).astype(np.float32)
img = img.transpose(2, 0, 1) / 255.
img = (img - mean) / var
return img, (orgH, orgW)



def allocate_buffers(engine):
h_input = cuda.pagelocked_empty(
trt.volume(engine.get_binding_shape(0)), dtype=in_datatype)
print(engine.get_binding_shape(0))
d_input = cuda.mem_alloc(h_input.nbytes)
h_outputs, d_outputs = [], []
n_outs = 1
for i in range(n_outs):
h_output = cuda.pagelocked_empty(
trt.volume(engine.get_binding_shape(i+1)),
dtype=out_datatype)
d_output = cuda.mem_alloc(h_output.nbytes)
h_outputs.append(h_output)
d_outputs.append(d_output)
stream = cuda.Stream()
return (
stream,
h_input,
d_input,
h_outputs,
d_outputs,
)


def build_engine_from_onnx(onnx_file_path):
engine = None ## add this to avoid return deleted engine
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser, trt.Runtime(TRT_LOGGER) as runtime:

# Parse model file
print(f'Loading ONNX file from path {onnx_file_path}...')
assert os.path.exists(onnx_file_path), f'cannot find {onnx_file_path}'
with open(onnx_file_path, 'rb') as fr:
if not parser.parse(fr.read()):
print ('ERROR: Failed to parse the ONNX file.')
for error in range(parser.num_errors):
print (parser.get_error(error))
assert False

# build settings
builder.max_batch_size = 128
config.max_workspace_size = 1 << 30 # 1G
if args.quant == 'fp16':
config.set_flag(trt.BuilderFlag.FP16)

print("Start to build Engine")
plan = builder.build_serialized_network(network, config)
engine = runtime.deserialize_cuda_engine(plan)
return engine


def serialize_engine_to_file(engine, savepth):
plan = engine.serialize()
with open(savepth, "wb") as fw:
fw.write(plan)


def deserialize_engine_from_file(savepth):
with open(savepth, 'rb') as fr, trt.Runtime(TRT_LOGGER) as runtime:
engine = runtime.deserialize_cuda_engine(fr.read())
return engine


def main():
if args.command == 'compile':
engine = build_engine_from_onnx(args.onnx)
serialize_engine_to_file(engine, args.savepth)

elif args.command == 'run':
engine = deserialize_engine_from_file(args.mdpth)

ishape = engine.get_binding_shape(0)
img, (orgH, orgW) = get_image(args.impth, ishape[2:])

## create engine and allocate bffers
(
stream,
h_input,
d_input,
h_outputs,
d_outputs,
) = allocate_buffers(engine)
ctx.push()
context = engine.create_execution_context()
ctx.pop()
bds = [int(d_input), ] + [int(el) for el in d_outputs]

h_input = np.ascontiguousarray(img)
cuda.memcpy_htod_async(d_input, h_input, stream)
context.execute_async(
bindings=bds, stream_handle=stream.handle)
for h_output, d_output in zip(h_outputs, d_outputs):
cuda.memcpy_dtoh_async(h_output, d_output, stream)
stream.synchronize()

out = palette[h_outputs[0]]
outshape = engine.get_binding_shape(1)
H, W = outshape[1], outshape[2]
out = out.reshape(H, W, 3)
out = cv2.resize(out, (orgW, orgH))
cv2.imwrite(args.outpth, out)



if __name__ == '__main__':
main()

0 comments on commit c2d90c4

Please sign in to comment.