Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
zylo117 committed Apr 6, 2020
1 parent bf8228b commit 6e01f7e
Show file tree
Hide file tree
Showing 15 changed files with 1,850 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/
.idea/
weights/
87 changes: 87 additions & 0 deletions backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Author: Zylo117

import math

import torch
from torch import nn

from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet
from efficientdet.utils import Anchors



class EfficientDetBackbone(nn.Module):
def __init__(self, num_anchors=9, num_classes=80, compound_coef=0, load_weights=False, **kwargs):
super(EfficientDetBackbone, self).__init__()
self.compound_coef = compound_coef

self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]
self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
self.anchor_scale = [4, 4, 3, 4, 4, 4, 4, 5]
self.aspect_ratios = [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]
self.num_scales = 3
self.anchor_scale = 4.0
conv_channel_coef = {
# TODO: I have only tested on D0/D2, if you want to try it on other coefficients,
# fill it in with the channels of P3/P4/P5 like this.
2: [48, 120, 352],
}

new_num_anchors = len(kwargs.get('ratios', [])) * len(kwargs.get('scales', []))
if new_num_anchors > 0:
num_anchors = new_num_anchors
else:
num_anchors = len(self.aspect_ratios) * self.num_scales

self.bifpn = nn.Sequential(
*[BiFPN(self.fpn_num_filters[self.compound_coef],
conv_channel_coef[compound_coef],
True if _ == 0 else False) for _ in range(self.fpn_cell_repeats[compound_coef])])

self.num_classes = num_classes
self.regressor = Regressor(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
num_layers=self.box_class_repeats[self.compound_coef])
self.classifier = Classifier(in_channels=self.fpn_num_filters[self.compound_coef], num_anchors=num_anchors,
num_classes=num_classes,
num_layers=self.box_class_repeats[self.compound_coef])

self.anchors = Anchors(image_size=self.input_sizes[compound_coef], **kwargs)

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

self.backbone_net = EfficientNet(compound_coef, load_weights)

def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()

def forward(self, inputs):
max_size = inputs.shape[-1]

_, p3, p4, p5 = self.backbone_net(inputs)

features = (p3, p4, p5)
features = self.bifpn(features)

regression = self.regressor(features)
classification = self.classifier(features)
anchors = self.anchors(inputs, inputs.dtype)

return features, regression, classification, anchors

def init_backbone(self, path):
state_dict = torch.load(path)
try:
ret = self.load_state_dict(state_dict, strict=False)
print(ret)
except RuntimeError as e:
print('Ignoring ' + str(e) + '"')
26 changes: 26 additions & 0 deletions efficientdet/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
COCO_CLASSES = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
"horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
"handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
"baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
"broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
"bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
"teddy bear", "hair drier", "toothbrush"]

colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), (99, 81, 172), (95, 56, 104), (37, 84, 86), (14, 89, 122),
(80, 7, 65), (10, 102, 25), (90, 185, 109), (106, 110, 132), (169, 158, 85), (188, 185, 26), (103, 1, 17),
(82, 144, 81), (92, 7, 184), (49, 81, 155), (179, 177, 69), (93, 187, 158), (13, 39, 73), (12, 50, 60),
(16, 179, 33), (112, 69, 165), (15, 139, 63), (33, 191, 159), (182, 173, 32), (34, 113, 133), (90, 135, 34),
(53, 34, 86), (141, 35, 190), (6, 171, 8), (118, 76, 112), (89, 60, 55), (15, 54, 88), (112, 75, 181),
(42, 147, 38), (138, 52, 63), (128, 65, 149), (106, 103, 24), (168, 33, 45), (28, 136, 135), (86, 91, 108),
(52, 11, 76), (142, 6, 189), (57, 81, 168), (55, 19, 148), (182, 101, 89), (44, 65, 179), (1, 33, 26),
(122, 164, 26), (70, 63, 134), (137, 106, 82), (120, 118, 52), (129, 74, 42), (182, 147, 112), (22, 157, 50),
(56, 50, 20), (2, 22, 177), (156, 100, 106), (21, 35, 42), (13, 8, 121), (142, 92, 28), (45, 118, 33),
(105, 118, 30), (7, 185, 124), (46, 34, 146), (105, 184, 169), (22, 18, 5), (147, 71, 73), (181, 64, 91),
(31, 39, 184), (164, 179, 33), (96, 50, 18), (95, 15, 106), (113, 68, 54), (136, 116, 112), (119, 139, 130),
(31, 139, 34), (66, 6, 127), (62, 39, 2), (49, 99, 180), (49, 119, 155), (153, 50, 183), (125, 38, 3),
(129, 87, 143), (49, 87, 40), (128, 62, 120), (73, 85, 148), (28, 144, 118), (29, 9, 24), (175, 45, 108),
(81, 175, 64), (178, 19, 157), (74, 188, 190), (18, 114, 2), (62, 128, 96), (21, 3, 150), (0, 6, 95),
(2, 20, 184), (122, 37, 185)]
181 changes: 181 additions & 0 deletions efficientdet/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import os
import torch
import numpy as np

from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
import cv2


class CocoDataset(Dataset):
def __init__(self, root_dir, set='train2017', transform=None):

self.root_dir = root_dir
self.set_name = set
self.transform = transform

self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json'))
self.image_ids = self.coco.getImgIds()

self.load_classes()

def load_classes(self):

# load class names (name -> label)
categories = self.coco.loadCats(self.coco.getCatIds())
categories.sort(key=lambda x: x['id'])

self.classes = {}
self.coco_labels = {}
self.coco_labels_inverse = {}
for c in categories:
self.coco_labels[len(self.classes)] = c['id']
self.coco_labels_inverse[c['id']] = len(self.classes)
self.classes[c['name']] = len(self.classes)

# also load the reverse (label -> name)
self.labels = {}
for key, value in self.classes.items():
self.labels[value] = key

def __len__(self):
return len(self.image_ids)

def __getitem__(self, idx):

img = self.load_image(idx)
annot = self.load_annotations(idx)
sample = {'img': img, 'annot': annot}
if self.transform:
sample = self.transform(sample)
return sample

def load_image(self, image_index):
image_info = self.coco.loadImgs(self.image_ids[image_index])[0]
path = os.path.join(self.root_dir, self.set_name, image_info['file_name'])
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

return img.astype(np.float32) / 255.

def load_annotations(self, image_index):
# get ground truth annotations
annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False)
annotations = np.zeros((0, 5))

# some images appear to miss annotations
if len(annotations_ids) == 0:
return annotations

# parse annotations
coco_annotations = self.coco.loadAnns(annotations_ids)
for idx, a in enumerate(coco_annotations):

# some annotations have basically no width / height, skip them
if a['bbox'][2] < 1 or a['bbox'][3] < 1:
continue

annotation = np.zeros((1, 5))
annotation[0, :4] = a['bbox']
annotation[0, 4] = self.coco_label_to_label(a['category_id'])
annotations = np.append(annotations, annotation, axis=0)

# transform from [x, y, w, h] to [x1, y1, x2, y2]
annotations[:, 2] = annotations[:, 0] + annotations[:, 2]
annotations[:, 3] = annotations[:, 1] + annotations[:, 3]

return annotations

def coco_label_to_label(self, coco_label):
return self.coco_labels_inverse[coco_label]

def label_to_coco_label(self, label):
return self.coco_labels[label]

def num_classes(self):
return 80


def collater(data):
imgs = [s['img'] for s in data]
annots = [s['annot'] for s in data]
scales = [s['scale'] for s in data]

imgs = torch.from_numpy(np.stack(imgs, axis=0))

max_num_annots = max(annot.shape[0] for annot in annots)

if max_num_annots > 0:

annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1

if max_num_annots > 0:
for idx, annot in enumerate(annots):
if annot.shape[0] > 0:
annot_padded[idx, :annot.shape[0], :] = annot
else:
annot_padded = torch.ones((len(annots), 1, 5)) * -1

imgs = imgs.permute(0, 3, 1, 2)

return {'img': imgs, 'annot': annot_padded, 'scale': scales}


class Resizer(object):
"""Convert ndarrays in sample to Tensors."""

def __call__(self, sample, common_size=512):
image, annots = sample['img'], sample['annot']
height, width, _ = image.shape
if height > width:
scale = common_size / height
resized_height = common_size
resized_width = int(width * scale)
else:
scale = common_size / width
resized_height = int(height * scale)
resized_width = common_size

image = cv2.resize(image, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)

new_image = np.zeros((common_size, common_size, 3))
new_image[0:resized_height, 0:resized_width] = image

annots[:, :4] *= scale

return {'img': torch.from_numpy(new_image), 'annot': torch.from_numpy(annots), 'scale': scale}


class Augmenter(object):
"""Convert ndarrays in sample to Tensors."""

def __call__(self, sample, flip_x=0.5):
if np.random.rand() < flip_x:
image, annots = sample['img'], sample['annot']
image = image[:, ::-1, :]

rows, cols, channels = image.shape

x1 = annots[:, 0].copy()
x2 = annots[:, 2].copy()

x_tmp = x1.copy()

annots[:, 0] = cols - x2
annots[:, 2] = cols - x_tmp

sample = {'img': image, 'annot': annots}

return sample


class Normalizer(object):

def __init__(self):
self.mean = np.array([[[0.485, 0.456, 0.406]]])
self.std = np.array([[[0.229, 0.224, 0.225]]])

def __call__(self, sample):
image, annots = sample['img'], sample['annot']

return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots}
Loading

0 comments on commit 6e01f7e

Please sign in to comment.