feat(torch): manage models with multiple losses
If a model outputs a dict containing multiple losses, they are retrieved and embedded in the training statistics.
Added a reg_weight parameter for the YOLOX model to control the balance between classification and regression losses.
Bycob authored and mergify[bot] committed Dec 14, 2021
1 parent 542bcb4 commit bea7cb4
Showing 10 changed files with 253 additions and 74 deletions.
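
For context before the diffs: the contract this commit handles is a forward() that may return a generic dict of named losses instead of a plain tensor. Every scalar tensor in the dict is surfaced as a named training metric, and for model-loss models a total_loss entry carries the value to backpropagate; reg_weight is read from the net object of the mllib parameters. Below is a minimal, self-contained libtorch sketch of that contract. It is illustrative only and not part of the commit: the loss values are made up, and the key names are taken from the yolox branch further down.

#include <torch/script.h>
#include <iostream>

int main()
{
  // Hypothetical per-component losses, as a YOLOX-style model might emit.
  torch::Tensor iou = torch::tensor(0.7), l1 = torch::tensor(0.2);
  torch::Tensor conf = torch::tensor(0.4), cls = torch::tensor(0.3);

  c10::impl::GenericDict dict(c10::StringType::get(), c10::TensorType::get());
  dict.insert(std::string("iou_loss"), iou);
  dict.insert(std::string("l1_loss"), l1);
  dict.insert(std::string("conf_loss"), conf);
  dict.insert(std::string("cls_loss"), cls);
  dict.insert(std::string("total_loss"), iou + l1 + conf + cls);

  c10::IValue out(dict); // stand-in for what forward() would return

  // torchlib.cc (below) walks such a dict and keeps every scalar tensor
  // as a named training metric.
  if (out.isGenericDict())
    for (const auto &e : out.toGenericDict())
      {
        const c10::IValue v = e.value();
        if (v.isTensor() && v.toTensor().numel() == 1)
          std::cout << e.key().toStringRef() << " = "
                    << v.toTensor().item<double>() << std::endl;
      }
  return 0;
}

With the commit applied, these sub-losses appear both in the per-iteration training logs and in the measure object returned by the training call.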
110 changes: 71 additions & 39 deletions src/backends/torch/torchlib.cc
@@ -688,7 +688,7 @@ namespace dd
if (!ad_distort.empty())
{
distort_params._prob = ad_distort.get("prob").get<double>();
this->_logger->info("noise: {}", distort_params._prob);
this->_logger->info("distort: {}", distort_params._prob);
}
inputc._dataset._img_rand_aug_cv = TorchImgRandAugCV(
has_mirror, has_rotate, crop_params, cutout_params,
@@ -723,17 +723,37 @@
}
}

// net params
int64_t batch_size = 1;
int64_t test_batch_size = 1;

if (ad_mllib.has("net"))
{
APIData ad_net = ad_mllib.getobj("net");
if (ad_net.has("batch_size"))
batch_size = ad_net.get("batch_size").get<int>();
if (ad_net.has("test_batch_size"))
test_batch_size = ad_net.get("test_batch_size").get<int>();
if (ad_net.has("reg_weight"))
{
_reg_weight = ad_net.get("reg_weight").get<double>();
this->_logger->info("reg_weight={}", _reg_weight);
}
}

// solver params
int64_t iterations = 1;
int64_t batch_size = 1;
int64_t iter_size = 1;
int64_t test_batch_size = 1;
int64_t test_interval = 1;
int64_t save_period = 0;

// loss specific to the model
if (_module.has_model_loss())
_loss = _template;

TorchLoss tloss(_loss, _module.has_model_loss(), _seq_training, _timeserie,
_regression, _classification, _segmentation, class_weights,
_module, this->_logger);
_reg_weight, _module, this->_logger);
TorchSolver tsolver(_module, tloss, this->_logger);

// logging parameters
@@ -758,15 +778,6 @@
save_period = ad_solver.get("snapshot").get<int>();
}

if (ad_mllib.has("net"))
{
APIData ad_net = ad_mllib.getobj("net");
if (ad_net.has("batch_size"))
batch_size = ad_net.get("batch_size").get<int>();
if (ad_net.has("test_batch_size"))
test_batch_size = ad_net.get("test_batch_size").get<int>();
}

bool retain_graph = ad_mllib.has("retain_graph")
? ad_mllib.get("retain_graph").get<bool>()
: false;
@@ -827,7 +838,7 @@
r.loss = std::make_shared<TorchLoss>(
_loss, r.module->has_model_loss(), _seq_training, _timeserie,
_regression, _classification, _segmentation, class_weights,
*r.module, this->_logger);
_reg_weight, *r.module, this->_logger);
}
}
_module.train();
@@ -846,6 +857,8 @@
double last_it_time = 0;
double last_test_time = 0;
double train_loss = 0;
std::unordered_map<std::string, double> sub_losses;
double loss_divider = iter_size * gpu_count;
auto data_it = dataloader->begin();

if (data_it == dataloader->end())
@@ -893,6 +906,7 @@
for (size_t rank = 0; rank < _devices.size(); ++rank)
{
double loss_val = 0;
c10::IValue out_val;
try
{
TorchBatch batch = batches[rank];
@@ -925,31 +939,13 @@
Tensor y = batch.target.at(0).to(device);

// Prediction
Tensor y_pred;
if (_segmentation)
{
auto out_dict
= rank_module.forward(in_vals).toGenericDict();
y_pred = torch_utils::to_tensor_safe(out_dict.at("out"));
}
else
{
y_pred = torch_utils::to_tensor_safe(
rank_module.forward(in_vals));
}

// sanity check
if (!y_pred.defined() || y_pred.numel() == 0)
throw MLLibInternalException(
"The model returned an empty tensor");
out_val = rank_module.forward(in_vals);

// Compute loss
Tensor loss = rank_tloss.loss(y_pred, y, in_vals);
Tensor loss = rank_tloss.loss(out_val, y, in_vals);

if (iter_size > 1)
loss /= iter_size;
if (gpu_count > 1)
loss /= static_cast<double>(gpu_count);
if (loss_divider != 1)
loss = loss / loss_divider;

// Backward
loss.backward(
@@ -969,8 +965,31 @@

#pragma omp critical
{
// Retain loss for statistics
// Retain loss and useful values for statistics
train_loss += loss_val;

if (out_val.isGenericDict())
{
auto out_dict = out_val.toGenericDict();
for (const auto &e : out_dict)
{
std::string key = e.key().toStringRef();
if (!e.value().isTensor())
continue;
auto val_t = e.value().toTensor();

// all scalar values are considered as metrics
if (val_t.numel() != 1)
continue;
double value = val_t.item<double>();
if (loss_divider != 1)
value /= loss_divider;
if (sub_losses.find(key) != sub_losses.end())
sub_losses[key] += value;
else
sub_losses[key] = value;
}
}
}
}

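A note on the accumulation above: gradients are accumulated over iter_size sub-batches and the per-rank losses are summed across gpu_count devices, so each contribution, the main loss and every scalar sub-loss alike, is divided once by

  loss_divider = iter_size * gpu_count

(defined earlier in torchlib.cc), which keeps the accumulated train_loss and sub_losses statistics on the scale of a single effective optimization step.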
@@ -1084,14 +1103,20 @@
this->add_meas_per_iter("elapsed_time_ms", elapsed_time_ms);
this->add_meas_per_iter("learning_rate", base_lr);
this->add_meas_per_iter("train_loss", train_loss);
for (auto e : sub_losses)
{
this->add_meas(e.first, e.second);
this->add_meas_per_iter(e.first, e.second);
}
int64_t elapsed_it = it + 1;
if (log_batch_period != 0 && elapsed_it % log_batch_period == 0)
{
this->_logger->info("Iteration {}/{}: loss is {}", elapsed_it,
iterations, train_loss);
for (auto e : sub_losses)
this->_logger->info("\t{}={}", e.first, e.second);
}
last_it_time = 0;
train_loss = 0;

if ((elapsed_it % test_interval == 0 && eval_dataset.size() != 0)
|| elapsed_it == iterations)
@@ -1106,9 +1131,13 @@
= duration_cast<milliseconds>(steady_clock::now() - tstart)
.count();

APIData meas_obj = meas_out.getobj("measure");
for (const auto &e : sub_losses)
meas_obj.add(e.first, e.second);
meas_out.add("measure", meas_obj);

for (size_t i = 0; i < eval_dataset.size(); ++i)
{
APIData meas_obj;
if (i == 0)
meas_obj = meas_out.getobj("measure");
else
@@ -1168,6 +1197,9 @@
out = meas_out;
}

train_loss = 0;
sub_losses.clear();

if ((save_period != 0 && elapsed_it % save_period == 0)
|| elapsed_it == iterations)
{
2 changes: 2 additions & 0 deletions src/backends/torch/torchlib.h
@@ -109,6 +109,8 @@ namespace dd
bool _bbox = false; /**< select detection type problem */
bool _segmentation = false; /**< select segmentation type problem */
std::string _loss = ""; /**< selected loss*/
double _reg_weight
= 1; /**< for detection models, weight for bbox regression loss. */

APIData _template_params; /**< template parameters, for recurrent and
native models*/
34 changes: 32 additions & 2 deletions src/backends/torch/torchloss.cc
@@ -25,21 +25,51 @@

namespace dd
{
torch::Tensor TorchLoss::reloss(torch::Tensor y_pred)
torch::Tensor TorchLoss::reloss(c10::IValue y_pred)
{
return loss(y_pred, _y, _ivx);
}

torch::Tensor TorchLoss::loss(torch::Tensor y_pred, torch::Tensor y,
torch::Tensor TorchLoss::loss(c10::IValue model_out, torch::Tensor y,
std::vector<c10::IValue> &ivx)
{
// memorize the values below to be able to redo the loss call (in case of solver.sam)
_y = y;
_ivx = ivx;
torch::Tensor x = ivx[0].toTensor();

torch::Tensor y_pred;
torch::Tensor loss;

if (model_out.isGenericDict())
{
auto out_dict = model_out.toGenericDict();
if (_segmentation)
y_pred = torch_utils::to_tensor_safe(out_dict.at("out"));
else if (_loss == "yolox")
{
torch::Tensor iou_loss
= torch_utils::to_tensor_safe(out_dict.at("iou_loss"));
torch::Tensor l1_loss
= torch_utils::to_tensor_safe(out_dict.at("l1_loss"));
torch::Tensor conf_loss
= torch_utils::to_tensor_safe(out_dict.at("conf_loss"));
torch::Tensor cls_loss
= torch_utils::to_tensor_safe(out_dict.at("cls_loss"));
y_pred = iou_loss * _reg_weight + l1_loss + conf_loss + cls_loss;
}
else // _model_loss = true
y_pred = torch_utils::to_tensor_safe(out_dict.at("total_loss"));
}
else
{
y_pred = torch_utils::to_tensor_safe(model_out);
}

// sanity check
if (!y_pred.defined() || y_pred.numel() == 0)
throw MLLibInternalException("The model returned an empty tensor");

if (_model_loss)
{
loss = y_pred;
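For the yolox branch above, the combined objective is simply

  total = _reg_weight * iou_loss + l1_loss + conf_loss + cls_loss

with _reg_weight defaulting to 1 (the plain YOLOX sum); values above 1 tilt training toward bbox regression relative to the classification and objectness terms, which is the balance the reg_weight API parameter controls.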
12 changes: 8 additions & 4 deletions src/backends/torch/torchloss.h
@@ -46,18 +46,20 @@ namespace dd
TorchLoss(std::string loss, bool model_loss, bool seq_training,
bool timeserie, bool regression, bool classification,
bool segmentation, torch::Tensor class_weights,
TorchModule &module, std::shared_ptr<spdlog::logger> logger)
double reg_weight, TorchModule &module,
std::shared_ptr<spdlog::logger> logger)
: _loss(loss), _model_loss(model_loss), _seq_training(seq_training),
_timeserie(timeserie), _regression(regression),
_classification(classification), _segmentation(segmentation),
_class_weights(class_weights), _logger(logger)
_class_weights(class_weights), _reg_weight(reg_weight),
_logger(logger)
{
_native = module._native;
}

torch::Tensor loss(torch::Tensor y_pred, torch::Tensor y,
torch::Tensor loss(c10::IValue model_out, torch::Tensor target,
std::vector<c10::IValue> &x);
torch::Tensor reloss(torch::Tensor y_pred);
torch::Tensor reloss(c10::IValue model_out);

std::vector<c10::IValue> getLastInputs()
{
@@ -73,6 +75,8 @@
bool _classification;
bool _segmentation;
torch::Tensor _class_weights = {};
double _reg_weight = 1; /**< for detection models, weight to apply to the
bbox regression loss */
std::shared_ptr<NativeModule> _native;
std::shared_ptr<spdlog::logger> _logger;
torch::Tensor _y_pred;
6 changes: 4 additions & 2 deletions src/backends/torch/torchmodule.h
@@ -62,8 +62,10 @@ namespace dd
* If more than one device is given, this method performs a multigpu
* forward.
*/
c10::IValue forward_on_devices(std::vector<c10::IValue> source,
const std::vector<torch::Device> &devices);
[[deprecated("Multigpu does not need this method anymore, cf "
"torchlib.cc:train()")]] c10::IValue
forward_on_devices(std::vector<c10::IValue> source,
const std::vector<torch::Device> &devices);

/**
* \brief forward (inference) until extract_layer, return value of
7 changes: 7 additions & 0 deletions tests/CMakeLists.txt
@@ -406,6 +406,13 @@ if (USE_TORCH)
"fasterrcnn_train_torch_bs2.tar.gz"
"fasterrcnn_train_torch"
)
DOWNLOAD_DATASET(
"Torchvision training YoloX Model"
"https://www.deepdetect.com/dd/examples/torch/yolox_train_torch.tar.gz"
"examples/torch"
"yolox_train_torch.tar.gz"
"yolox_train_torch"
)
DOWNLOAD_DATASET(
"Torchvision inference DeepLabV3 Resnet50 model"
"https://www.deepdetect.com/dd/examples/torch/deeplabv3_torch.tar.gz"