from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, NewType, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.tokenization_utils import PreTrainedTokenizer


class DataCollator(ABC):
    """
    A `DataCollator` is responsible for batching
    and pre-processing samples of data as requested by the training loop.
    """

    @abstractmethod
    def collate_batch(self) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.

        Returns:
            A dictionary of tensors
        """
        pass


InputDataClass = NewType("InputDataClass", Any)


@dataclass
class DefaultDataCollator(DataCollator):
    """
    Very simple data collator that:
    - simply collates batches of dict-like objects
    - performs special handling for potential keys named:
        - `label`: handles a single value (int or float) per object
        - `label_ids`: handles a list of values per object
    - does not do any additional preprocessing

    i.e., property names of the input object will be used as corresponding inputs to the model.
    See glue and ner for examples of how it's useful.
    """

    def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
        # In this method we'll make the assumption that all `features` in the batch
        # have the same attributes.
        # So we will look at the first element as a proxy for what attributes exist
        # on the whole batch.
        first = features[0]

        # Special handling for labels.
        # Ensure that the tensor is created with the correct type
        # (it should be automatically the case, but let's make sure of it).
        if hasattr(first, "label") and first.label is not None:
            if type(first.label) is int:
                labels = torch.tensor([f.label for f in features], dtype=torch.long)
            else:
                labels = torch.tensor([f.label for f in features], dtype=torch.float)
            batch = {"labels": labels}
        elif hasattr(first, "label_ids") and first.label_ids is not None:
            if type(first.label_ids[0]) is int:
                labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
            else:
                labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
            batch = {"labels": labels}
        else:
            batch = {}

        # Handling of all other possible attributes.
        # Again, we will use the first element to figure out which key/values are not None for this model.
        for k, v in vars(first).items():
            if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
                batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
        return batch
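

# A minimal usage sketch for `DefaultDataCollator` (illustrative only; the `ExampleFeature`
# dataclass below is hypothetical and not part of this module). Each non-string attribute of
# the feature objects becomes a model input, and `label` is collected into a `labels` tensor:
#
#     @dataclass
#     class ExampleFeature:
#         input_ids: List[int]
#         attention_mask: List[int]
#         label: int
#
#     features = [
#         ExampleFeature(input_ids=[101, 2023, 102], attention_mask=[1, 1, 1], label=0),
#         ExampleFeature(input_ids=[101, 2008, 102], attention_mask=[1, 1, 1], label=1),
#     ]
#     batch = DefaultDataCollator().collate_batch(features)
#     # batch["input_ids"].shape == (2, 3); batch["labels"] == tensor([0, 1]) with dtype torch.long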


@dataclass
class DataCollatorForLanguageModeling(DataCollator):
    """
    Data collator used for language modeling.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling
    """

    tokenizer: PreTrainedTokenizer
    mlm: bool = True
    mlm_probability: float = 0.15

    def collate_batch(self, examples: List[Tuple[torch.Tensor, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        batch, tokens = self._tensorize_batch(examples)
        if self.mlm:
            inputs, labels = self.mask_tokens(batch, tokens)
            return {"input_ids": inputs, "masked_lm_labels": labels}
        else:
            return {"input_ids": batch, "labels": batch}

    def _tensorize_batch(self, examples: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids = [example[0] for example in examples]
        words_token = [example[1] for example in examples]
        length_of_first = input_ids[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in input_ids)
        if are_tensors_same_length:
            return torch.stack(input_ids, dim=0), torch.stack(words_token, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have a pad token."
                )
            return (
                pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id),
                pad_sequence(words_token, batch_first=True, padding_value=0),
            )

    def mask_tokens(self, inputs: torch.Tensor, batch_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        Masking decisions are made per word rather than per token: a marker of 2 in `batch_tokens` flags a
        continuation sub-word that inherits the masking decision of its word, and a marker of 0 flags a
        special/padding token that is never masked.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling."
                " Remove the --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # We sample a few words in each sequence for masked-LM training (with probability
        # self.mlm_probability, 0.15 by default as in BERT/RoBERTa). This replaces the stock
        # per-token probability-matrix sampling (which zeroed out special tokens and padding)
        # with one Bernoulli draw per word, expanded to all of the word's sub-word tokens.
        masked_indices = []
        for tokens in batch_tokens:
            cand_indexes = []
            prob = []
            for i, token in enumerate(tokens):
                if token == 0:
                    # Special or padding token: its own group, never masked.
                    cand_indexes.append([i])
                    prob.append(0.0)
                elif len(cand_indexes) >= 1 and token == 2:
                    # Continuation sub-word: attach to the previous word's group.
                    cand_indexes[-1].append(i)
                else:
                    # Start of a new word.
                    cand_indexes.append([i])
                    prob.append(self.mlm_probability)
            masked_words = torch.bernoulli(torch.tensor(prob)).bool()
            masked_row = []
            for i, indexes in enumerate(cand_indexes):
                for index in indexes:
                    masked_row.append(bool(masked_words[i]))
            masked_indices.append(masked_row)
        masked_indices = torch.tensor(masked_indices, dtype=torch.bool)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with a random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
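

# A minimal, illustrative driver for `DataCollatorForLanguageModeling` (a sketch, not part of the
# module's API). Assumptions: a BERT-style checkpoint such as "bert-base-uncased" is available, and
# each example is an (input_ids, word_markers) pair where a marker of 0 flags a special token,
# 2 flags a continuation sub-word, and any other value (1 here) flags the start of a word,
# matching what `mask_tokens` above expects.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

    def encode(text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        # Build the (input_ids, word_markers) pair: 0 = special token, 1 = word start, 2 = continuation.
        tokens = tokenizer.tokenize(text)
        input_ids = tokenizer.convert_tokens_to_ids([tokenizer.cls_token] + tokens + [tokenizer.sep_token])
        markers = [0] + [2 if t.startswith("##") else 1 for t in tokens] + [0]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(markers, dtype=torch.long)

    examples = [encode("unaffordable housing"), encode("a short sentence")]
    batch = collator.collate_batch(examples)
    # `input_ids` now contains [MASK]/random replacements for roughly 15% of whole words;
    # `masked_lm_labels` is -100 everywhere except at the masked positions.
    print(batch["input_ids"].shape, batch["masked_lm_labels"].shape)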