bert_vit.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
BertModel,
BertConfig,
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from .modules import EncoderRNN, BiAttention, get_aggregated
from transformers import ViTModel, ViTConfig
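

# Configuration for the WebShop choice model: `pretrained_bert` selects whether to
# load `bert-base-uncased` weights, and `image` enables a frozen ViT branch for
# product-image features.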
class BertConfigForWebshop(PretrainedConfig):
model_type = "bert"
def __init__(
self,
pretrained_bert=True,
image=False,
**kwargs
):
self.pretrained_bert = pretrained_bert
self.image = image
super().__init__(**kwargs)
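

# Scores each candidate action against the current observation: BERT encodes both,
# a bi-attention layer fuses them, and a small MLP produces one logit per action.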
class BertVitForWebshop(PreTrainedModel):
config_class = BertConfigForWebshop
def __init__(self, config):
super().__init__(config)
if config.pretrained_bert:
self.bert = BertModel.from_pretrained('bert-base-uncased')
else:
self.bert = BertModel(config)
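        # Resize embeddings for the extra WebShop special tokens added on top of
        # BERT's 30522-token vocabulary.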
self.bert.resize_token_embeddings(30526)
self.attn = BiAttention(768, 0.0)
self.linear_1 = nn.Linear(768 * 4, 768)
self.relu = nn.ReLU()
self.linear_2 = nn.Linear(768, 1)
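        # Optional image branch: a ViT encoder (kept frozen here) whose [CLS]
        # embedding is prepended to the state representation in forward().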
if config.image:
vit_config = ViTConfig.from_pretrained('google/vit-base-patch16-224-in21k')
self.vit = ViTModel(vit_config)
for param in self.vit.parameters():
param.requires_grad = False
else:
self.vit = None
# for state value prediction, used in RL
self.linear_3 = nn.Sequential(
nn.Linear(768, 128),
nn.LeakyReLU(),
nn.Linear(128, 1),
)
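
    # `sizes[i]` gives the number of candidate actions for state i; the action batch
    # is the flat concatenation of all candidates. Returns per-state log-softmax
    # scores over each state's candidate actions (and a cross-entropy-style loss if
    # `labels` is given).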
def forward(self, state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, raw_images=None, labels=None):
sizes = sizes.tolist()
state_rep = self.bert(state_input_ids, attention_mask=state_attention_mask)[0]
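        # When raw images are provided and the ViT branch is enabled, prepend the ViT
        # [CLS] embedding to the token-level state representation and extend the
        # attention mask accordingly.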
if raw_images is not None and self.vit is not None:
image_emb = self.vit(pixel_values=raw_images).last_hidden_state
image_emb = image_emb[:, 0, :] # Take the [CLS] token
state_rep = torch.cat([image_emb.unsqueeze(dim=1), state_rep], dim=1)
state_attention_mask = torch.cat([state_attention_mask[:, :1], state_attention_mask], dim=1)
action_rep = self.bert(action_input_ids, attention_mask=action_attention_mask)[0]
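        # Repeat each state representation (and its mask) once per candidate action
        # so it aligns with the flattened action batch.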
state_rep = torch.cat([state_rep[i:i+1].repeat(j, 1, 1) for i, j in enumerate(sizes)], dim=0)
state_attention_mask = torch.cat([state_attention_mask[i:i+1].repeat(j, 1) for i, j in enumerate(sizes)], dim=0)
act_lens = action_attention_mask.sum(1).tolist()
state_action_rep = self.attn(action_rep, state_rep, state_attention_mask)
state_action_rep = self.relu(self.linear_1(state_action_rep))
act_values = get_aggregated(state_action_rep, act_lens, 'mean')
act_values = self.linear_2(act_values).squeeze(1)
logits = [F.log_softmax(_, dim=0) for _ in act_values.split(sizes)]
loss = None
if labels is not None:
loss = - sum([logit[label] for logit, label in zip(logits, labels)]) / len(logits)
return SequenceClassifierOutput(
loss=loss,
logits=logits,
)
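
    # RL-time wrapper: scores the valid actions for each state in a batch and
    # optionally predicts a state value via `linear_3`. Gradients are disabled while
    # acting (`act=True`) and enabled otherwise.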
def rl_forward(self, state_batch, act_batch, value=False, q=False, act=False):
act_values = []
act_sizes = []
values = []
for state, valid_acts in zip(state_batch, act_batch):
with torch.set_grad_enabled(not act):
state_ids = torch.tensor([state.obs]).cuda()
state_mask = (state_ids > 0).int()
act_lens = [len(_) for _ in valid_acts]
act_ids = [torch.tensor(_) for _ in valid_acts]
act_ids = nn.utils.rnn.pad_sequence(act_ids, batch_first=True).cuda()
act_mask = (act_ids > 0).int()
act_size = torch.tensor([len(valid_acts)]).cuda()
                # NOTE: this branch follows the feature-based variant of the model; with
                # the ViT branch it assumes `state.image_feat` already holds ViT-compatible
                # pixel tensors, and the zero placeholder below may need a matching shape.
                if self.vit is not None:
                    images = [state.image_feat]
                    images = [torch.zeros(512) if _ is None else _ for _ in images]
                    images = torch.stack(images).cuda()  # BS x 512
                else:
                    images = None
                logits = self.forward(state_ids, state_mask, act_ids, act_mask, act_size, raw_images=images).logits[0]
act_values.append(logits)
act_sizes.append(len(valid_acts))
if value:
v = self.bert(state_ids, state_mask)[0]
values.append(self.linear_3(v[0][0]))
act_values = torch.cat(act_values, dim=0)
act_values = torch.cat([F.log_softmax(_, dim=0) for _ in act_values.split(act_sizes)], dim=0)
# Optionally, output state value prediction
if value:
values = torch.cat(values, dim=0)
return act_values, act_sizes, values
else:
return act_values, act_sizes
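

# Minimal usage sketch (the tensor names and shapes below are illustrative
# assumptions, not part of this file):
#
#     config = BertConfigForWebshop(pretrained_bert=True, image=False)
#     model = BertVitForWebshop(config)
#     out = model(state_input_ids, state_attention_mask,
#                 action_input_ids, action_attention_mask,
#                 sizes=torch.tensor([num_candidate_actions]))
#     out.logits  # one log-softmax tensor per state in the batch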