forked from llcing/VGG_dml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Faster_CUB.py
65 lines (45 loc) · 1.78 KB
/
Faster_CUB.py
1
"""Scratch listing of CUB-200-2011 training images plus the ``CUB200`` wrapper.

Importing this module runs the listing code below: it walks the first four
entries found under ``root``, writes the image paths it finds to
``listfile.txt`` in the current working directory, and prints the paths of the
last directory it visited.
"""
from __future__ import absolute_import

import os
import os.path as osp

import torchvision.datasets as datasets
from torchvision import transforms

root = '/opt/intern/users/xunwang/DataSet/CUB_200_2011/train/'
Dirs = os.listdir(root)

# Pre-bind so the final ``print`` cannot raise NameError when ``root`` has no
# entries and the loop body never runs.
images = []
labels = []

# Open the list file ONCE, around the loop. Re-opening it with mode 'w' inside
# the loop truncated the file on every iteration, so only the paths of the last
# visited directory were ever kept.
with open('listfile.txt', 'w') as filehandle:
    for i, dir_ in enumerate(Dirs):
        images = os.listdir(osp.join(root, dir_))
        # One label (the directory's enumeration index) per image. Currently
        # computed but not written anywhere.
        labels = len(images) * [i]
        images = [osp.join(root, dir_, img) for img in images]
        filehandle.writelines("%s\n" % place for place in images)
        # Only the first four directories (indices 0..3) are processed.
        if i == 3:
            break
print(images)


class CUB200:
    """Bundle the CUB-200-2011 train/test splits as ``ImageFolder`` datasets.

    Attributes:
        train: ``datasets.ImageFolder`` over ``<root>/train``. Only set when
            the ``train`` argument is truthy.
        test: ``datasets.ImageFolder`` over ``<root>/test``. Only set when
            the ``test`` argument is truthy.
    """

    def __init__(self, root, train=True, test=True, transform=None):
        """Build the requested splits.

        Args:
            root: Directory containing the ``train`` and ``test``
                sub-directories. ``None`` selects the hard-coded default path.
            train: When truthy, create ``self.train``.
            test: When truthy, create ``self.test``.
            transform: Optional two-item sequence; item 0 is applied to the
                train split and item 1 to the test split. ``None`` selects the
                default pipelines built below.
        """
        # Channel-wise normalisation; these values are the ImageNet
        # statistics documented for torchvision's pretrained models.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        if transform is None:
            # NOTE: a BGR-conversion step (``transforms.CovertBGR()``) was
            # commented out at the head of both pipelines in the original.
            train_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(scale=(0.16, 1), size=224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
            test_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
            transform = [train_transform, test_transform]

        if root is None:
            root = '/opt/intern/users/xunwang/DataSet/CUB_200_2011'

        traindir = os.path.join(root, 'train')
        testdir = os.path.join(root, 'test')

        if train:
            self.train = datasets.ImageFolder(traindir, transform[0])
        if test:
            self.test = datasets.ImageFolder(testdir, transform[1])