Skip to content

Commit

Permalink
[FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder. (#640)
Browse files Browse the repository at this point in the history

* [FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder.
  • Loading branch information
byshiue authored Aug 6, 2020
1 parent 280e75c commit 1aa6813
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2

/**
 * Reports whether this plugin accepts the given data type / format pair.
 *
 * @param type   TensorRT data type being queried.
 * @param format Tensor format being queried.
 * @return true only when `type` equals TransformerTrtTraits<T>::DataType and
 *         `format` is PluginFormat::kNCHW.
 */
bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override
{
  // Compare against the traits-provided data type instead of a hard-coded
  // kFLOAT: a leading kFLOAT-only return would make this check unreachable
  // and reject every instantiation whose traits type is not kFLOAT
  // (presumably the half-precision one -- confirm against TransformerTrtTraits).
  return type == TransformerTrtTraits<T>::DataType && format == PluginFormat::kNCHW;
}

void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim,
int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override
{
assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW);
assert(dataType == TransformerTrtTraits<T>::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW);
assert(nInputDim == 2);
assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_);
assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class TRT_Transformer

builder->setMaxBatchSize(batch_size_);
builder->setMaxWorkspaceSize(1 << 20);
builder->setFp16Mode(false);
builder->setFp16Mode(sizeof(T) == 2);

engine_ = builder->buildCudaEngine(*network);
assert(engine_);
Expand Down

0 comments on commit 1aa6813

Please sign in to comment.