Skip to content

Commit

Permalink
feat: use DTO for NCNN init parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
sileht committed Feb 12, 2021
1 parent 00036b1 commit 635e2c1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 46 deletions.
56 changes: 17 additions & 39 deletions src/backends/ncnn/ncnnlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "outputconnectorstrategy.h"
#include <thread>
#include <algorithm>
#include "utils/utils.hpp"

// NCNN
#include "ncnnlib.h"
Expand Down Expand Up @@ -53,10 +52,10 @@ namespace dd
{
this->_libname = "ncnn";
_net = new ncnn::Net();
_net->opt.num_threads = _threads;
_net->opt.num_threads = 1;
_net->opt.blob_allocator = &_blob_pool_allocator;
_net->opt.workspace_allocator = &_workspace_pool_allocator;
_net->opt.lightmode = _lightmode;
_net->opt.lightmode = true;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand All @@ -69,12 +68,9 @@ namespace dd
this->_libname = "ncnn";
_net = tl._net;
tl._net = nullptr;
_nclasses = tl._nclasses;
_threads = tl._threads;
_timeserie = tl._timeserie;
_old_height = tl._old_height;
_inputBlob = tl._inputBlob;
_outputBlob = tl._outputBlob;
_init_dto = tl._init_dto;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand All @@ -94,6 +90,8 @@ namespace dd
void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::init_mllib(const APIData &ad)
{
_init_dto = ad.createSharedDTO<NcnnInitDto>();

bool use_fp32 = (ad.has("datatype")
&& ad.get("datatype").get<std::string>()
== "fp32"); // default is fp16
Expand Down Expand Up @@ -124,35 +122,11 @@ namespace dd
_old_height = this->_inputc.height();
_net->set_input_h(_old_height);

if (ad.has("nclasses"))
_nclasses = ad.get("nclasses").get<int>();

if (ad.has("threads"))
_threads = ad.get("threads").get<int>();
else
_threads = dd_utils::my_hardware_concurrency();

_timeserie = this->_inputc._timeserie;
if (_timeserie)
this->_mltype = "timeserie";

if (ad.has("lightmode"))
{
_lightmode = ad.get("lightmode").get<bool>();
_net->opt.lightmode = _lightmode;
}

// setting the value of Input Layer
if (ad.has("inputblob"))
{
_inputBlob = ad.get("inputblob").get<std::string>();
}
// setting the final Output Layer
if (ad.has("outputblob"))
{
_outputBlob = ad.get("outputblob").get<std::string>();
}

_net->opt.lightmode = _init_dto->lightmode;
_blob_pool_allocator.set_size_compare_ratio(0.0f);
_workspace_pool_allocator.set_size_compare_ratio(0.5f);
model_type(this->_mlmodel._params, this->_mltype);
Expand Down Expand Up @@ -233,7 +207,10 @@ namespace dd

// Extract detection or classification
int ret = 0;
std::string out_blob = _outputBlob;
std::string out_blob;
if (_init_dto->outputBlob)
out_blob = _init_dto->outputBlob->std_str();

if (out_blob.empty())
{
if (bbox == true)
Expand Down Expand Up @@ -262,11 +239,11 @@ namespace dd
{
best = ad_output.get("best").get<int>();
}
if (best == -1 || best > _nclasses)
best = _nclasses;
if (best == -1 || best > _init_dto->nclasses)
best = _init_dto->nclasses;

// for loop around batch size
#pragma omp parallel for num_threads(_threads)
#pragma omp parallel for num_threads(*_init_dto->threads)
for (size_t b = 0; b < inputc._ids.size(); b++)
{
std::vector<double> probs;
Expand All @@ -276,8 +253,8 @@ namespace dd
APIData rad;

ncnn::Extractor ex = _net->create_extractor();
ex.set_num_threads(_threads);
ex.input(_inputBlob.c_str(), inputc._in.at(b));
ex.set_num_threads(_init_dto->threads);
ex.input(_init_dto->inputBlob->c_str(), inputc._in.at(b));

ret = ex.extract(out_blob.c_str(), inputc._out.at(b));
if (ret == -1)
Expand Down Expand Up @@ -423,7 +400,8 @@ namespace dd
} // end for batch_size

tout.add_results(vrad);
out.add("nclasses", this->_nclasses);
int nclasses = this->_init_dto->nclasses;
out.add("nclasses", nclasses);
if (bbox == true)
out.add("bbox", true);
out.add("roi", false);
Expand Down
14 changes: 7 additions & 7 deletions src/backends/ncnn/ncnnlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
#ifndef NCNNLIB_H
#define NCNNLIB_H

#include "apidata.h"
#include "utils/utils.hpp"

#include "dto/ncnn.hpp"

// NCNN
#include "net.h"
#include "ncnnmodel.h"

#include "apidata.h"

namespace dd
{
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
Expand All @@ -53,20 +56,17 @@ namespace dd

public:
ncnn::Net *_net = nullptr;
int _nclasses = 0;
bool _timeserie = false;
bool _lightmode = true;

private:
std::shared_ptr<NcnnInitDto> _init_dto;
static ncnn::UnlockedPoolAllocator _blob_pool_allocator;
static ncnn::PoolAllocator _workspace_pool_allocator;

protected:
int _threads = 1;
int _old_height = -1;
std::string _inputBlob = "data";
std::string _outputBlob;
};

}

#endif

0 comments on commit 635e2c1

Please sign in to comment.