From 542bcb49870c82d2bccfd1bf68ac2eaa76e30846 Mon Sep 17 00:00:00 2001
From: Guillaume Infantes
Date: Fri, 10 Dec 2021 17:19:11 +0100
Subject: [PATCH] feat(torch): dice loss

https://arxiv.org/abs/1707.03237
---
 docs/api.md                     |  2 +-
 src/backends/torch/torchloss.cc | 69 +++++++++++++++++++++++++-
 src/backends/torch/torchloss.h  |  1 +
 tests/ut-torchapi.cc            | 85 +++++++++++++++++++++++++++++++++
 4 files changed, 154 insertions(+), 3 deletions(-)

diff --git a/docs/api.md b/docs/api.md
index 301866bb7..231b03287 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -271,7 +271,7 @@ mirror | bool | yes | false |
 finetuning | bool | yes | false | Whether to prepare neural net template for finetuning (requires `weights`)
 db | bool | yes | false | whether to set a database as input of neural net, useful for handling large datasets and training in constant-memory (requires `mlp` or `convnet`)
 scaling_temperature | real | yes | 1.0 | sets the softmax temperature of an existing network (e.g. useful for model calibration)
-loss | string | yes | N/A | Special network losses, from `dice`, `dice_multiclass`, `dice_weighted`, `dice_weighted_batch` or `dice_weighted_all`, useful for image segmentation, and `L1` or `L2`, useful for time-series via `csvts` connector
+loss | string | yes | N/A | Special network losses, from `dice` (direct IOU maximization), `dice_multiclass` (same as `dice` for the torch backend, different implementation for the caffe backend), `dice_weighted` (dice augmented with inter-class weighting based on image stats), `dice_weighted_batch` (dice augmented with inter-class weighting based on batch stats) or `dice_weighted_all` (dice augmented with inter-class weighting based on running stats over all seen data), useful for image segmentation, and `L1` or `L2`, useful for time-series via `csvts` connector
 ssd_expand_prob | float | yes | between 0 and 1, probability of expanding the image (to improve detection of small/very small objects)
 ssd_max_expand_ratio | float | yes | bbox zoom out ratio, e.g. 4.0
 ssd_mining_type | str | yes | N/A | "HARD_EXAMPLE" or "MAX_NEGATIVE"
diff --git a/src/backends/torch/torchloss.cc b/src/backends/torch/torchloss.cc
index 05f02be6e..83ed061ed 100644
--- a/src/backends/torch/torchloss.cc
+++ b/src/backends/torch/torchloss.cc
@@ -21,6 +21,7 @@
 #include "torchloss.h"
 #pragma GCC diagnostic pop
 
+#include
 
 namespace dd
 {
@@ -88,8 +89,72 @@ namespace dd
       }
     else if (_segmentation)
       {
-        loss = torch::nn::functional::cross_entropy(
-            y_pred, y.squeeze(1).to(torch::kLong)); // TODO: options
+        if (_loss.empty())
+          {
+
+            loss = torch::nn::functional::cross_entropy(
+                y_pred, y.squeeze(1).to(torch::kLong)); // TODO: options
+          }
+        else if (_loss == "dice" || _loss == "dice_multiclass"
+                 || _loss == "dice_weighted" || _loss == "dice_weighted_batch"
+                 || _loss == "dice_weighted_all")
+          {
+            // see https://arxiv.org/abs/1707.03237
+            double smooth = 1e-7;
+            torch::Tensor y_true_f
+                = torch::one_hot(y.to(torch::kInt64), y_pred.size(1))
+                      .squeeze(1)
+                      .permute({ 0, 3, 1, 2 })
+                      .flatten(2)
+                      .to(torch::kFloat32);
+            torch::Tensor y_pred_f = torch::flatten(torch::sigmoid(y_pred), 2);
+
+            torch::Tensor intersect;
+            torch::Tensor denom;
+
+            if (_loss == "dice" || _loss == "dice_multiclass")
+              {
+                intersect = y_true_f * y_pred_f;
+                denom = y_true_f + y_pred_f;
+              }
+            else if (_loss == "dice_weighted")
+              {
+                torch::Tensor sum = torch::sum(y_true_f, { 2 }) + 1.0;
+                torch::Tensor weights = 1.0 / sum / sum;
+                intersect = torch::sum(y_true_f * y_pred_f, { 2 }) * weights;
+                denom = torch::sum(y_true_f + y_pred_f, { 2 }) * weights;
+              }
+            else if (_loss == "dice_weighted_batch"
+                     || _loss == "dice_weighted_all")
+              {
+                torch::Tensor sum
+                    = torch::sum(y_true_f, std::vector<int64_t>({ 0, 2 }))
+                      + 1.0;
+                torch::Tensor weights = 1.0 / sum / sum;
+                if (_loss == "dice_weighted_all")
+                  {
+                    if (_num_batches == 0)
+                      _class_weights = weights;
+                    else
+                      {
+                        weights = (_class_weights * _num_batches + weights)
+                                  / (_num_batches + 1);
+                        _class_weights = weights;
+                      }
+                    _num_batches++;
+                  }
+                intersect = torch::sum(y_true_f * y_pred_f,
+                                       std::vector<int64_t>({ 0, 2 }))
+                            * weights;
+                denom = torch::sum(y_true_f + y_pred_f,
+                                   std::vector<int64_t>({ 0, 2 }))
+                        * weights;
+              }
+
+            return 1.0 - torch::mean(2.0 * intersect / (denom + smooth));
+          }
+        else
+          throw MLLibBadParamException("unknown loss: " + _loss);
       }
     else
       {
diff --git a/src/backends/torch/torchloss.h b/src/backends/torch/torchloss.h
index 8b8610790..dd64d57c4 100644
--- a/src/backends/torch/torchloss.h
+++ b/src/backends/torch/torchloss.h
@@ -78,6 +78,7 @@
     torch::Tensor _y_pred;
     torch::Tensor _y;
     std::vector _ivx;
+    long int _num_batches = 0;
   };
 }
 #endif
diff --git a/tests/ut-torchapi.cc b/tests/ut-torchapi.cc
index 23b953fc7..7a37fc5e4 100644
--- a/tests/ut-torchapi.cc
+++ b/tests/ut-torchapi.cc
@@ -826,6 +826,91 @@ TEST(torchapi, service_train_image_segmentation)
   fileops::remove_dir(deeplabv3_train_repo + "test_0.lmdb");
 }
 
+TEST(torchapi, service_train_image_segmentation_dice)
+{
+  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
+  torch::manual_seed(torch_seed);
+  at::globalContext().setDeterministicCuDNN(true);
+
+  // Create service
+  JsonAPI japi;
+  std::string sname = "imgserv";
+  std::string jstr
+      = "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
+        "\"supervised\",\"model\":{\"repository\":\""
+        + deeplabv3_train_repo
+        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
+          "\"width\":480,\"height\":480,\"db\":true,\"segmentation\":true},"
+          "\"mllib\":{\"nclasses\":"
"13,\"gpu\":true,\"segmentation\":true,\"loss\":\"dice_weighted_" + "all\"}}" + "}"; + std::string joutstr = japi.jrender(japi.service_create(sname, jstr)); + ASSERT_EQ(created_str, joutstr); + + // Train + std::string jtrainstr + = "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{" + "\"mllib\":{\"solver\":{\"iterations\":" + + iterations_deeplabv3 + ",\"base_lr\":" + torch_lr + + ",\"iter_size\":1,\"solver_type\":\"ADAM\",\"test_" + "interval\":100},\"net\":{\"batch_size\":4}," + "\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224," + "\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":" + "true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true," + "\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0." + "01}}," + "\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true," + "\"segmentation\":true,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406]" + ",\"std\":[0.229,0.224,0.225]}," + "\"output\":{\"measure\":[\"meaniou\",\"acc\"]}},\"data\":[\"" + + deeplabv3_train_data + "\",\"" + deeplabv3_test_data + "\"]}"; + joutstr = japi.jrender(japi.service_train(jtrainstr)); + JDoc jd; + std::cout << "joutstr=" << joutstr << std::endl; + jd.Parse(joutstr.c_str()); + ASSERT_TRUE(!jd.HasParseError()); + ASSERT_EQ(201, jd["status"]["code"]); + + ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() <= 1) << "accuracy"; + ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() >= 0.007) + << "accuracy good"; + ASSERT_TRUE(jd["body"]["measure"]["meaniou"].GetDouble() <= 1) << "meaniou"; + + std::string jpredictstr + = "{\"service\":\"imgserv\",\"parameters\":{" + "\"input\":{\"height\":480," + "\"width\":480,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406],\"std\":[" + "0.229,0.224,0.225]},\"output\":{\"segmentation\":true, " + "\"confidences\":[\"best\"]}},\"data\":[\"" + + deeplabv3_test_image + "\"]}"; + + joutstr = japi.jrender(japi.service_predict(jpredictstr)); + std::cout << "joutstr=" << joutstr << std::endl; + jd.Parse(joutstr.c_str()); + ASSERT_TRUE(!jd.HasParseError()); + ASSERT_EQ(200, jd["status"]["code"]); + ASSERT_TRUE(jd["body"]["predictions"].IsArray()); + + std::unordered_set lfiles; + fileops::list_directory(deeplabv3_train_repo, true, false, false, lfiles); + for (std::string ff : lfiles) + { + if (ff.find("checkpoint") != std::string::npos + || ff.find("solver") != std::string::npos) + remove(ff.c_str()); + } + ASSERT_TRUE(!fileops::file_exists(deeplabv3_train_repo + "checkpoint-" + + iterations_deeplabv3 + ".ptw")); + ASSERT_TRUE(!fileops::file_exists(deeplabv3_train_repo + "checkpoint-" + + iterations_deeplabv3 + ".pt")); + + fileops::clear_directory(deeplabv3_train_repo + "train.lmdb"); + fileops::clear_directory(deeplabv3_train_repo + "test_0.lmdb"); + fileops::remove_dir(deeplabv3_train_repo + "train.lmdb"); + fileops::remove_dir(deeplabv3_train_repo + "test_0.lmdb"); +} + TEST(torchapi, service_publish_trained_model) { setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);