diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..6e969d2
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,8 @@
+*.swp
+**/__pycache__/**
+.idea/*
+ckpt/
+*.pth
+*.log
+*.txt
+.dockerignore
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..cb3c2f5
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,10 @@
+FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
+
+WORKDIR /
+
+RUN pip install timm==0.5.4
+
+COPY /pretrain/requirements.txt /
+RUN pip install --no-cache-dir -r requirements.txt
+
+CMD ["bash"]
diff --git a/downstream_imagenet/arg.py b/downstream_imagenet/arg.py
index a7435bf..956b63a 100644
--- a/downstream_imagenet/arg.py
+++ b/downstream_imagenet/arg.py
@@ -12,6 +12,11 @@
 HP_DEFAULT_NAMES = ['bs', 'ep', 'wp_ep', 'opt', 'base_lr', 'lr_scale', 'wd', 'mixup', 'rep_aug', 'drop_path', 'ema']
 HP_DEFAULT_VALUES = {
+    'convnext_atto': (1024, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
+    'convnext_femto': (1024, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
+    'convnext_pico': (512, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
+    'convnext_nano': (512, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
+    'convnext_tiny': (256, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
     'convnext_small': (4096, 400, 20, 'adam', 0.0002, 0.7, 0.01, 0.8, 3, 0.3, 0.9999),
     'convnext_base': (4096, 400, 20, 'adam', 0.0001, 0.7, 0.01, 0.8, 3, 0.4, 0.9999),
     'convnext_large': (4096, 200, 10, 'adam', 0.0001, 0.7, 0.02, 0.8, 3, 0.5, 0.9999),
diff --git a/downstream_imagenet/models/convnext_official.py b/downstream_imagenet/models/convnext_official.py
index 2f47c48..9955ea8 100644
--- a/downstream_imagenet/models/convnext_official.py
+++ b/downstream_imagenet/models/convnext_official.py
@@ -141,19 +141,67 @@ def forward(self, x):
             x = self.weight[:, None, None] * x + self.bias[:, None, None]
             return x
 
-
+# Pretrained weights available at https://github.com/facebookresearch/ConvNeXt and https://github.com/facebookresearch/ConvNeXt-V2
 model_urls = {
-    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
-    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
-    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
-    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
-    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
-    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
-    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
-    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
-    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+    "convnext_atto_1k": "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt",  # ConvNeXt-V2 ImageNet-1k fine-tuning
+    "convnext_femto_1k": "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt",  # ConvNeXt-V2 ImageNet-1k fine-tuning
+    "convnext_pico_1k": "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt",  # ConvNeXt-V2 ImageNet-1k fine-tuning
+    "convnext_nano_1k": "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt",  # ConvNeXt-V2 ImageNet-1k fine-tuning
"https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt", #ConvNeXt-V2 ImageNet1k fine-tuning + "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", #ConvNeXt-V1 supervised training + "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", #ConvNeXt-V1 supervised training + "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", #ConvNeXt-V1 supervised training + "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", #ConvNeXt-V1 supervised training + "convnext_nano_22k" : "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt", #ConvNeXt-V2 ImageNet22k fine-tuning + "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", #ConvNeXt-V1 supervised training + "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", #ConvNeXt-V1 supervised training + "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", #ConvNeXt-V1 supervised training + "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", #ConvNeXt-V1 supervised training + "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", #ConvNeXt-V1 supervised training } +@register_model +def convnext_atto(pretrained=False,in_22k=False, **kwargs): + model = ConvNeXt(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) + if pretrained: + if in_22k: + raise NotImplementedError("Add weights to load.") + url = model_urls['convnext_atto_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def convnext_femto(pretrained=False,in_22k=False, **kwargs): + model = ConvNeXt(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) + if pretrained: + if in_22k: + raise NotImplementedError("Add weights to load.") + url = model_urls['convnext_atto_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def convnext_pico(pretrained=False,in_22k=False, **kwargs): + model = ConvNeXt(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) + if pretrained: + if in_22k: + raise NotImplementedError("Add weights to load.") + url = model_urls['convnext_atto_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + model.load_state_dict(checkpoint["model"]) + return model + +@register_model +def convnext_nano(pretrained=False,in_22k=False, **kwargs): + model = ConvNeXt(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) + if pretrained: + url = model_urls['convnext_nano_22k'] if in_22k else model_urls['convnext_nano_1k'] + checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + model.load_state_dict(checkpoint["model"]) + return model + @register_model def convnext_tiny(pretrained=False,in_22k=False, **kwargs): model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) @@ -199,3 +247,14 @@ def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") model.load_state_dict(checkpoint["model"]) return model + +@register_model +def 
+def convnext_huge(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
+    if pretrained:
+        if in_22k:
+            raise NotImplementedError("Add weights to load.")
+        url = model_urls['convnext_huge_1k']
+        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+        model.load_state_dict(checkpoint["model"])
+    return model
\ No newline at end of file
diff --git a/pretrain/models/__init__.py b/pretrain/models/__init__.py
index c632e73..0fde90f 100644
--- a/pretrain/models/__init__.py
+++ b/pretrain/models/__init__.py
@@ -37,9 +37,16 @@ def _ex_repr(self):
     'resnet101': dict(drop_path_rate=0.08),
     'resnet152': dict(drop_path_rate=0.10),
     'resnet200': dict(drop_path_rate=0.15),
+    'convnext_atto': dict(sparse=True, drop_path_rate=0.2),
+    'convnext_femto': dict(sparse=True, drop_path_rate=0.2),
+    'convnext_pico': dict(sparse=True, drop_path_rate=0.2),
+    'convnext_nano': dict(sparse=True, drop_path_rate=0.2),
+    'convnext_tiny': dict(sparse=True, drop_path_rate=0.2),
     'convnext_small': dict(sparse=True, drop_path_rate=0.2),
     'convnext_base': dict(sparse=True, drop_path_rate=0.3),
     'convnext_large': dict(sparse=True, drop_path_rate=0.4),
+    'convnext_xlarge': dict(sparse=True, drop_path_rate=0.4),
+    'convnext_huge': dict(sparse=True, drop_path_rate=0.4),
 }
 for kw in pretrain_default_model_kwargs.values():
     kw['pretrained'] = False
diff --git a/pretrain/models/convnext.py b/pretrain/models/convnext.py
index 6b0169e..ea6b06b 100644
--- a/pretrain/models/convnext.py
+++ b/pretrain/models/convnext.py
@@ -99,27 +99,52 @@ def get_classifier(self):
 
     def extra_repr(self):
         return f'drop_path_rate={self.drop_path_rate}, layer_scale_init_value={self.layer_scale_init_value:g}'
+@register_model
+def convnext_atto(pretrained=False,in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
+    return model
 
 
 @register_model
-def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
-    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+def convnext_femto(pretrained=False,in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
     return model
+@register_model
+def convnext_pico(pretrained=False,in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
+    return model
+
+@register_model
+def convnext_nano(pretrained=False,in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
+    return model
+
+@register_model
+def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+    return model
 
 
 @register_model
 def convnext_small(pretrained=False, in_22k=False, **kwargs):
     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
     return model
 
-
 @register_model
 def convnext_base(pretrained=False, in_22k=False, **kwargs):
     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
     return model
 
-
 @register_model
 def convnext_large(pretrained=False, in_22k=False, **kwargs):
     model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
     return model
+@register_model
+def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+    return model
+
+@register_model
+def convnext_huge(pretrained=False, in_22k=False, **kwargs):
+    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
+    return model
\ No newline at end of file
diff --git a/pretrain/pretrain.sh b/pretrain/pretrain.sh
new file mode 100755
index 0000000..c133cec
--- /dev/null
+++ b/pretrain/pretrain.sh
@@ -0,0 +1,52 @@
+# encoder hyperparameters
+MODEL="convnext_atto"
+INPUT_SIZE=64
+SBN=True
+
+# SparK hyperparameters
+MASK=0.6
+
+# data hyperparameters
+BATCH_SIZE=2
+DATALOADER_WORKERS=4
+
+# pre-training hyperparameters
+DP=0.0
+BASE_LR=2e-4
+WD=0.04
+WDE=0.2
+EP=1600
+WP_EP=40
+CLIP=5
+OPT='lamb'
+ADA=0.
+
+# environment
+EXP_NAME="${MODEL}_mask_${MASK}_is_${INPUT_SIZE}_bs_${BATCH_SIZE}_baselr_${BASE_LR}_epochs_${EP}_opt_${OPT}"
+EXP_DIR="/log_dir/${EXP_NAME}"  # will be created if it does not exist
+DATA_PATH='../../imagenet100'
+INIT_WEIGHT=''   # use some checkpoint as model weight initialization; ONLY loads model weights
+RESUME_FROM=''   # resume the experiment from some checkpoint.pth; loads model weights, optimizer states, and last epoch
+
+
+python main.py \
+  --model $MODEL \
+  --input_size $INPUT_SIZE \
+  --sbn $SBN \
+  --mask $MASK \
+  --batch_size $BATCH_SIZE \
+  --dataloader_workers $DATALOADER_WORKERS \
+  --dp $DP \
+  --base_lr $BASE_LR \
+  --wd $WD \
+  --wde $WDE \
+  --ep $EP \
+  --wp_ep $WP_EP \
+  --clip $CLIP \
+  --opt $OPT \
+  --ada $ADA \
+  --exp_name "$EXP_NAME" \
+  --exp_dir "$EXP_DIR" \
+  --data_path "$DATA_PATH" \
+  --init_weight "$INIT_WEIGHT" \
+  --resume_from "$RESUME_FROM"
\ No newline at end of file
diff --git a/pretrain/requirements.txt b/pretrain/requirements.txt
index 896e276..2ab13f8 100644
--- a/pretrain/requirements.txt
+++ b/pretrain/requirements.txt
@@ -4,3 +4,4 @@ Pillow
 typed-argument-parser
 timm==0.5.4
 tensorboardx
+tensorboard
\ No newline at end of file