From d20aa4e22ec87d307d693903f515d226a7209874 Mon Sep 17 00:00:00 2001 From: yue kun Date: Sun, 30 Oct 2022 15:54:48 +0800 Subject: [PATCH] rm some useless code --- OCR/MGP-STR/test_final.py | 11 +++-------- OCR/MGP-STR/train_final_dist.py | 5 +---- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/OCR/MGP-STR/test_final.py b/OCR/MGP-STR/test_final.py index ad0e362..a06b623 100644 --- a/OCR/MGP-STR/test_final.py +++ b/OCR/MGP-STR/test_final.py @@ -239,12 +239,7 @@ def validation(model, criterion, evaluation_loader, converter, opt): if out_pred == gt: out_n_correct += 1 - # calculate confidence score (= multiply of pred_max_prob) - try: - confidence_score = char_preds_max_prob[index].cumprod(dim=0)[-1] - except: - confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s]) - confidence_score_list.append(confidence_score) + confidence_score_list.append(char_confidence_score) elif opt.Transformer in ["char-str"]: attens, char_preds = model(image, is_eval=True) # final @@ -393,8 +388,8 @@ def test(opt): _, accuracy_by_best_model, _, _, _, _, _, _ = validation( model, criterion, evaluation_loader, converter, opt) log.write(eval_data_log) - print(f'{accuracy_by_best_model:0.3f}') - log.write(f'{accuracy_by_best_model:0.3f}\n') + print(f'{accuracy_by_best_model[0]:0.3f}') + log.write(f'{accuracy_by_best_model[0]:0.3f}\n') log.close() # https://github.com/clovaai/deep-text-recognition-benchmark/issues/125 diff --git a/OCR/MGP-STR/train_final_dist.py b/OCR/MGP-STR/train_final_dist.py index a251f31..31db029 100644 --- a/OCR/MGP-STR/train_final_dist.py +++ b/OCR/MGP-STR/train_final_dist.py @@ -79,10 +79,7 @@ def train(opt): if opt.saved_model != '': print(f'loading pretrained model from {opt.saved_model}') - if opt.FT: - model.load_state_dict(torch.load(opt.saved_model, map_location='cpu'), strict=True) - else: - model.load_state_dict(torch.load(opt.saved_model, map_location='cpu'), strict=True) + model.load_state_dict(torch.load(opt.saved_model, map_location='cpu'), strict=True) """ setup loss """ criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0