data_helper.py
import re

import numpy as np
import pandas as pd


def clean_str(s):
    """Clean a sentence: normalize contractions and punctuation, then lowercase."""
    s = re.sub(r"[^A-Za-z0-9(),!?'`]", " ", s)  # Keep only letters, digits and basic punctuation
    s = re.sub(r"'s", " 's", s)  # Split common contractions into separate tokens
    s = re.sub(r"'ve", " 've", s)
    s = re.sub(r"n't", " n't", s)
    s = re.sub(r"'re", " 're", s)
    s = re.sub(r"'d", " 'd", s)
    s = re.sub(r"'ll", " 'll", s)
    s = re.sub(r",", " , ", s)  # Pad punctuation with spaces so each mark becomes its own token
    s = re.sub(r"!", " ! ", s)
    s = re.sub(r"\(", " ( ", s)
    s = re.sub(r"\)", " ) ", s)
    s = re.sub(r"\?", " ? ", s)
    s = re.sub(r"\s{2,}", " ", s)  # Collapse repeated whitespace
    s = re.sub(r"\S*(x{2,}|X{2,})\S*", "xxx", s)  # Collapse tokens containing runs of x/X (redacted fields) to 'xxx'
    s = re.sub(r"[^\x00-\x7F]+", "", s)  # Strip non-ASCII characters
    return s.strip().lower()
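
# For illustration only (made-up input, not part of the original file):
#   clean_str("I can't believe it's not XXXX!")
# returns
#   "i ca n't believe it 's not xxx !"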


def load_data_and_labels(filename):
    """Load consumer complaint narratives and map each product label to a one-hot vector."""
    df = pd.read_csv(filename, compression='zip', dtype={'consumer_complaint_narrative': object})
    selected = ['product', 'consumer_complaint_narrative']
    non_selected = list(set(df.columns) - set(selected))

    df = df.drop(non_selected, axis=1)  # Drop non-selected columns
    df = df.dropna(axis=0, how='any', subset=selected)  # Drop rows with missing values
    df = df.reindex(np.random.permutation(df.index))  # Shuffle the dataframe

    # Map the actual labels to one-hot vectors
    labels = sorted(list(set(df[selected[0]].tolist())))
    one_hot = np.zeros((len(labels), len(labels)), int)
    np.fill_diagonal(one_hot, 1)
    label_dict = dict(zip(labels, one_hot))

    x_raw = df[selected[1]].apply(clean_str).tolist()
    y_raw = df[selected[0]].apply(lambda y: label_dict[y]).tolist()
    return x_raw, y_raw, df, labels
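
# For illustration only (hypothetical labels, not part of the original file): with
# labels == ['A', 'B', 'C'], label_dict maps 'A' -> [1, 0, 0], 'B' -> [0, 1, 0],
# 'C' -> [0, 0, 1], so y_raw holds one row of this identity matrix per complaint.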


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """Iterate over the data batch by batch, reshuffling every epoch if requested."""
    data = np.array(data)
    data_size = len(data)
    num_batches_per_epoch = (data_size - 1) // batch_size + 1  # Ceiling division, no empty final batch
    for epoch in range(num_epochs):
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]


if __name__ == '__main__':
    input_file = './data/consumer_complaints.csv.zip'
    load_data_and_labels(input_file)
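
    # Illustrative smoke test of batch_iter, not part of the original file: it batches
    # made-up index data; batch_size=4 and num_epochs=1 are arbitrary example values.
    # Training code would typically pass pairs of cleaned texts and one-hot labels instead.
    for batch in batch_iter(np.arange(10), batch_size=4, num_epochs=1):
        print('batch of size', len(batch))  # Expect sizes 4, 4, 2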