-
Notifications
You must be signed in to change notification settings - Fork 8
/
model_attention.py
65 lines (52 loc) · 1.64 KB
/
model_attention.py
1
# coding=utf-8from __future__ import absolute_import, print_functionimport timeimport argparseimport osimport sysimport torch.utils.datafrom torch.backends import cudnnfrom torch.autograd import Variableimport modelsimport lossesfrom utils import RandomIdentitySampler, mkdir_if_missing, logging, displayimport DataSetimport numpy as npcudnn.benchmark = True# import os# os.environ["CUDA_VISIBLE_DEVICES"] = '6'# ABE = models.create('vgg_attention')def load_parameter(ABE): vgg = models.create('vgg', pretrained=True) vgg_dict = vgg.state_dict() vgg_name = [k for k, _ in vgg_dict.items() if k in vgg_dict] # print(vgg_name) # ABE_dict = ABE.state_dict() ABE_dict = ABE.state_dict() ABE_name = [k for k, _ in ABE_dict.items() if k in ABE_dict] # print(ABE_name) Num_overlap_layers = 78 for i in range(Num_overlap_layers): ABE_dict[ABE_name[i]] = vgg_dict[vgg_name[i]] ABE.load_state_dict(ABE_dict) return ABE# print(ABE.state_dict()[ABE_name[0]][0][0])# print(vgg.state_dict()[vgg_name[0]][0][0])model = models.create(args.net)model_dict = model.state_dict()model_name = [k for k, _ in model_dict.items() if k in model_dict]model_weights = [v for k, v in model_dict.items() if k in model_dict]print(model_name)n = 0for z in model_weights: n += 1 print(z.shape)z = np.array([torch.abs(torch.sum(k * l)).cpu().numpy() for l in V])# model_dict.update(pretrained_dict)# model_dict['features.34.weight'].shape# model.features = torch.nn.Sequential(# model.features,# torch.nn.MaxPool2d(7),# # torch.nn.BatchNorm2d(512),# torch.nn.Dropout(p=0.01)# )mode