Commit

Adjust to not scan the entire training set for each batch. This is just a bad hack, see nearai#5 for discussion of the issue.
fac2003 committed Aug 2, 2018
1 parent 63e6768 commit ccf964d
Showing 2 changed files with 19 additions and 7 deletions.
20 changes: 16 additions & 4 deletions examples/snli/spinn-example.py
@@ -11,6 +11,7 @@

from torchtext import data
from torchtext import datasets
+from torchtext.data import Example

import torchfold

@@ -50,7 +51,11 @@ def __init__(self, n_classes, size, n_words):
self.out = nn.Linear(size, n_classes)

def leaf(self, word_id):
-        return self.embeddings(word_id), torch.Tensor(word_id.size()[0], self.size, dtype=torch.float32)
+        embedded = self.embeddings(word_id)
+        return embedded, \
+               torch.zeros((word_id.size(0),
+                            self.size))\
+                   .type(dtype=torch.float32)

def children(self, left_h, left_c, right_h, right_c):
return self.tree_lstm((left_h, left_c), (right_h, right_c))
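
Note on the leaf() change: the removed line passed dtype= to the legacy torch.Tensor constructor, which does not accept that keyword, and even without it the second return value would have been uninitialized memory rather than a zero LSTM cell state. A minimal one-line sketch of an equivalent fix, assuming the same SPINN.leaf signature; torch.zeros produces float32 by default, so the committed .type(dtype=torch.float32) cast is redundant:

    def leaf(self, word_id):
        # Token embedding plus a zero-initialized cell state of shape (batch, size).
        return self.embeddings(word_id), torch.zeros(word_id.size(0), self.size, dtype=torch.float32)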
@@ -125,7 +130,7 @@ def main():
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
(train, dev, test), batch_size=args.batch_size, device=0 if args.cuda else -1)
print("Done.")
-model = SPINN(3, 50, 100)
+model = SPINN(3, 500, 10000)
criterion = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=0.01)
device = torch.device('cuda' if args.cuda else 'cpu')
@@ -140,15 +145,22 @@ def main():

all_logits, all_labels = [], []
fold = torchfold.Fold(device=device)
+count=0
+# TODO: incorrect logic here, the for loop goes through the entire dataset, not the batch:
for example in batch.dataset:
print("example:"+str(example))
#print("example:"+str(example))
tree = Tree(example, inputs.vocab, answers.vocab)
if args.fold:
all_logits.append(encode_tree_fold(fold, tree))
else:
all_logits.append(encode_tree_regular(model, tree))
all_labels.append(tree.label)
+# We use the following BAD workaround to test the later part of folding and training:
+count+=1
+if count>args.batch_size:
+    # we stop after seeing batchsize examples, but they are not from the batch, but from the
+    # entire dataset, presumably never shuffled.
+    break

if args.fold:
res = fold.apply(model, [all_logits, all_labels])
@@ -160,7 +172,7 @@ def main():
loss.backward(); opt.step()

iteration += 1
-if iteration % 1 == 1:
+if iteration % 10 == 1:
print("Avg. Time: %fs" % ((time.time() - start) / iteration))
print("iteration {} loss:{}".format(iteration, loss))
# iteration = 0
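The count-based break above is the "bad hack" from the commit message: batch.dataset is the entire training set, so every step re-reads the first batch_size + 1 examples of the unshuffled dataset instead of the examples belonging to the current batch. A hypothetical sketch of one cleaner alternative, batching the Example objects directly rather than going through BucketIterator; the example_batches helper is invented here for illustration and is not the fix discussed in nearai#5:

    import random

    def example_batches(dataset, batch_size):
        # Shuffle once per pass, then yield disjoint slices of batch_size examples.
        examples = list(dataset.examples)
        random.shuffle(examples)
        for start in range(0, len(examples), batch_size):
            yield examples[start:start + batch_size]

    # Each training step then sees a distinct, shuffled slice:
    # for examples in example_batches(train, args.batch_size):
    #     for example in examples:
    #         tree = Tree(example, inputs.vocab, answers.vocab)
    #         ...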
6 changes: 3 additions & 3 deletions torchfold/torchfold.py
@@ -124,7 +124,7 @@ def _batch_args(self, arg_lists, values):
res.append(arg[0].get(values))
elif all(isinstance(arg_item, int) for arg_item in arg):

-var = torch.Tensor(arg,dtype=torch.long).to(self._device)
+var = torch.Tensor(arg).type(dtype=torch.long).to(self._device)
#var = Variable(torch.LongTensor(arg), volatile=self.volatile).to(self._device)
res.append(var)
else:
@@ -220,8 +220,8 @@ def _arg(self, arg):
return arg.tensor
elif isinstance(arg, int):

-return torch.Tensor([arg], volatile=self.volatile, dtype=torch.long).to(self._device)
#return Variable(torch.LongTensor([arg]), volatile=self.volatile).to(self._device)
+return torch.Tensor([arg], volatile=self.volatile).type(dtype=torch.long).to(self._device)

else:
return arg

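Both torchfold.py changes work around the same constraint: the legacy torch.Tensor constructor does not accept a dtype keyword (and volatile, a Variable-era flag removed in PyTorch 0.4, is not a constructor argument either), hence the separate .type(dtype=torch.long) cast. A minimal sketch of the more direct modern form using the torch.tensor factory, which takes dtype and device in a single call; shown here as an alternative, not what this commit does:

    import torch

    device = torch.device('cpu')
    arg = [3, 1, 4]

    # One call replaces torch.Tensor(arg).type(dtype=torch.long).to(device):
    var = torch.tensor(arg, dtype=torch.long, device=device)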
