feat(preprocessor): add preprocessor feature #34

Open · wants to merge 20 commits into base: main
9 changes: 7 additions & 2 deletions entity_embed/__init__.py
@@ -2,11 +2,16 @@
import logging

# libgomp issue, must import n2 before torch. See: https://github.com/kakao/n2/issues/42
import n2 # noqa: F401
# import n2 # noqa: F401

from .data_modules import * # noqa: F401, F403
from .data_utils.field_config_parser import FieldConfigDictParser # noqa: F401
from .data_utils.numericalizer import default_tokenizer # noqa: F401
from .data_utils.numericalizer import (
default_tokenizer,
remove_space_digit_punc,
remove_places,
default_pre_processor,
) # noqa: F401
from .entity_embed import * # noqa: F401, F403
from .indexes import * # noqa: F401, F403

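For reference, these new exports are the pre-processors that field_config_parser.py resolves by dotted path (see below). A minimal sketch of a field config that opts into them; the field names, the extra keys, and the field_type casing are illustrative, not taken from this PR:

# Sketch only: the parser looks up "pre_processor" / "multi_pre_processor" with
# _import_function and falls back to entity_embed.default_pre_processor.
# "company_name" and "address" are hypothetical field names; field_type values
# here simply mirror the FieldType enum names and are an assumption.
field_config_dict = {
    "company_name": {
        "field_type": "MULTITOKEN",
        "tokenizer": "entity_embed.default_tokenizer",
        "multi_pre_processor": "entity_embed.remove_places",
    },
    "address": {
        "field_type": "STRING",
        "pre_processor": "entity_embed.remove_space_digit_punc",
        "max_str_len": 30,
    },
}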
3 changes: 2 additions & 1 deletion entity_embed/cli.py
@@ -369,7 +369,8 @@ def _load_model(kwargs):

model = model_cls.load_from_checkpoint(kwargs["model_save_filepath"], datamodule=None)
if kwargs["use_gpu"]:
model = model.to(torch.device("cuda"))
# model = model.to(torch.device("cuda"))
model = model.to(torch.device("mps"))
else:
model = model.to(torch.device("cpu"))
return model
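Hard-coding torch.device("mps") will raise on machines without Apple's Metal backend. A defensive sketch (not part of this PR) that falls back at runtime, assuming PyTorch >= 1.12 where torch.backends.mps exists:

# Sketch: choose the best available device instead of hard-coding "mps".
import torch

def _select_device(use_gpu):
    # Prefer CUDA, then Apple's MPS backend, then CPU.
    if use_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    if use_gpu and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# e.g. inside _load_model: model = model.to(_select_device(kwargs["use_gpu"]))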
50 changes: 47 additions & 3 deletions entity_embed/data_utils/field_config_parser.py
@@ -2,7 +2,10 @@
import logging
from importlib import import_module

from torchtext.vocab import Vocab
import torch
from torch import Tensor, nn
from torchtext.vocab import Vocab, Vectors, FastText
from torchtext.vocab import vocab as factory_vocab

from .numericalizer import (
AVAILABLE_VOCABS,
@@ -63,9 +66,16 @@ def _parse_field_config(cls, field, field_config, record_list):
tokenizer = _import_function(
field_config.get("tokenizer", "entity_embed.default_tokenizer")
)
pre_processor = _import_function(
field_config.get("pre_processor", "entity_embed.default_pre_processor")
)
multi_pre_processor = _import_function(
field_config.get("multi_pre_processor", "entity_embed.default_pre_processor")
)
alphabet = field_config.get("alphabet", DEFAULT_ALPHABET)
max_str_len = field_config.get("max_str_len")
vocab = None
vector_tensor = None

# Check if there's a key defined on the field_config,
# useful when we want to have multiple FieldConfig for the same field
@@ -92,8 +102,39 @@ def _parse_field_config(cls, field, field_config, record_list):
"field_config if you wish to use a override "
"an field name."
)
vocab = Vocab(vocab_counter)
vocab.load_vectors(vocab_type)

if vocab_type in {"tx_embeddings_large.vec", "tx_embeddings.vec"}:
vectors = Vectors(vocab_type, cache=".vector_cache")
elif vocab_type == "fasttext":
vectors = FastText("en") # might need to add standard fasttext
else:
vocab.load_vectors(vocab_type) # won't work: vocab is still None at this point

# adding <unk> token
unk_token = "<unk>"
vocab = factory_vocab(vocab_counter, specials=[unk_token])
# print(vocab["<unk>"]) # prints 0
# make default index same as index of unk_token
vocab.set_default_index(vocab[unk_token])
# print(vocab["probably out of vocab"]) # prints 0

# create vector tensor using tokens in vocab, order important
vectors = [vectors]
# device = torch.device("mps")
tot_dim = sum(v.dim for v in vectors)
# vector_tensor = torch.zeros(len(vocab), tot_dim)
vector_tensor = torch.Tensor(len(vocab), tot_dim)

for i, token in enumerate(vocab.get_itos()):
start_dim = 0
for v in vectors:
end_dim = start_dim + v.dim
vector_tensor[i][start_dim:end_dim] = v[token.strip()]
start_dim = end_dim
assert start_dim == tot_dim

logger.info(f"Vector tensor shape: {vector_tensor.shape}")
assert len(vector_tensor) == len(vocab)

# Compute max_str_len if necessary
if field_type in (FieldType.STRING, FieldType.MULTITOKEN) and (max_str_len is None):
@@ -125,9 +166,12 @@ def _parse_field_config(cls, field, field_config, record_list):
key=key,
field_type=field_type,
tokenizer=tokenizer,
pre_processor=pre_processor,
multi_pre_processor=multi_pre_processor,
alphabet=alphabet,
max_str_len=max_str_len,
vocab=vocab,
vector_tensor=vector_tensor,
n_channels=n_channels,
embed_dropout_p=embed_dropout_p,
use_attention=use_attention,
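The vector_tensor built above is row-aligned with the vocab indices (including the <unk> row at index 0), which is what lets it seed an embedding layer downstream. A minimal sketch of that relationship, assuming vocab and vector_tensor are the objects constructed in _parse_field_config:

# Sketch: rows of vector_tensor line up with vocab indices, so an embedding
# layer can be initialised from it; out-of-vocab tokens map to <unk> (index 0).
# "vocab" and "vector_tensor" are assumed to be the objects built above.
import torch
from torch import nn

embedding = nn.Embedding.from_pretrained(vector_tensor, freeze=False)
token_ids = torch.tensor([vocab[t] for t in ["acme", "definitely-not-in-vocab"]])
token_vectors = embedding(token_ids)  # shape: (2, tot_dim)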
47 changes: 46 additions & 1 deletion entity_embed/data_utils/numericalizer.py
@@ -4,10 +4,12 @@
from enum import Enum
from typing import Callable, List

from string import punctuation
import numpy as np
import regex
import torch
from torchtext.vocab import Vocab
from flashgeotext.geotext import GeoText, GeoTextConfiguration

logger = logging.getLogger(__name__)

@@ -27,6 +29,7 @@
"glove.6B.100d",
"glove.6B.200d",
"glove.6B.300d",
"tx_embeddings_large.vec",
]


@@ -41,10 +44,13 @@ class FieldType(Enum):
class FieldConfig:
key: str
field_type: FieldType
pre_processor: Callable[[str], List[str]]
multi_pre_processor: Callable[[str], List[str]]
tokenizer: Callable[[str], List[str]]
alphabet: List[str]
max_str_len: int
vocab: Vocab
vector_tensor: torch.Tensor
n_channels: int
embed_dropout_p: float
use_attention: bool
@@ -60,7 +66,7 @@ def __repr__(self):
repr_dict = {}
for k, v in self.__dict__.items():
if isinstance(v, Callable):
repr_dict[k] = f"{inspect.getmodule(v).__name__}.{v.__name__}"
repr_dict[k] = f"{inspect.getmodule(v).__name__}.{getattr(v, '__name__', repr(v))}"
else:
repr_dict[k] = v
return "{cls}({attrs})".format(
@@ -77,6 +83,31 @@ def default_tokenizer(val):
return tokenizer_re.findall(val)


def remove_space_digit_punc(val):
val = "".join(c for c in val if (not c.isdigit()) and (c not in punctuation))
return val.replace(" ", "")


config = GeoTextConfiguration(**{"case_sensitive": False})
geotext = GeoText(config)


def default_pre_processor(text):
return text


def remove_places(text):
places = geotext.extract(text)
found_places = []
for i, v in places.items():
for w, x in v.items():
word = x["found_as"][0]
if word not in ["at", "com", "us", "usa"]:
found_places.append(word)
text = text.replace(word, "")
return text


class StringNumericalizer:
is_multitoken = False

Expand All @@ -85,6 +116,8 @@ def __init__(self, field, field_config):
self.alphabet = field_config.alphabet
self.max_str_len = field_config.max_str_len
self.char_to_ord = {c: i for i, c in enumerate(self.alphabet)}
self.pre_processor = field_config.pre_processor
print(f"Found pre_processor {self.pre_processor} for field {self.field}")

def _ord_encode(self, val):
ord_encoded = []
@@ -100,10 +133,15 @@ def build_tensor(self, val):
# encoded_arr is a one hot encoded bidimensional tensor
# with characters as rows and positions as columns.
# This is the shape expected by StringEmbedCNN.
# if val != self.pre_processor(val):
# print(f"{val} -> {self.pre_processor(val)} -> {self.pre_processor} -> {self.field}")
val = self.pre_processor(val)
ord_encoded_val = self._ord_encode(val)
ord_encoded_val = ord_encoded_val[: self.max_str_len] # truncate to max_str_len
encoded_arr = np.zeros((len(self.alphabet), self.max_str_len), dtype=np.float32)
if len(ord_encoded_val) > 0:
encoded_arr[ord_encoded_val, range(len(ord_encoded_val))] = 1.0

t = torch.from_numpy(encoded_arr)
return t, len(val)

@@ -127,10 +165,16 @@ class MultitokenNumericalizer:

def __init__(self, field, field_config):
self.field = field
self.field_type = field_config.field_type
self.multi_pre_processor = field_config.multi_pre_processor
self.tokenizer = field_config.tokenizer
self.string_numericalizer = StringNumericalizer(field=field, field_config=field_config)
print(f"Found multi_pre_processor {self.multi_pre_processor} for field {self.field}")

def build_tensor(self, val):
# if val != self.multi_pre_processor(val):
# print(f"{val} -> {self.multi_pre_processor(val)} -> {self.multi_pre_processor} -> {self.field}")
val = self.multi_pre_processor(val)
val_tokens = self.tokenizer(val)
t_list = []
for v in val_tokens:
@@ -149,6 +193,7 @@ class SemanticMultitokenNumericalizer(MultitokenNumericalizer):
def __init__(self, field, field_config):
self.field = field
self.tokenizer = field_config.tokenizer
self.multi_pre_processor = field_config.multi_pre_processor
self.string_numericalizer = SemanticStringNumericalizer(
field=field, field_config=field_config
)
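The new pre-processors are plain str -> str helpers applied before tokenization. A rough illustration of their behaviour; the remove_places result depends on what flashgeotext's GeoText actually detects, so that line is an assumption:

# Rough illustration only; exact remove_places output depends on GeoText's index.
from entity_embed import default_pre_processor, remove_places, remove_space_digit_punc

remove_space_digit_punc("Acme Corp. 123 Ltd")  # -> "AcmeCorpLtd"
default_pre_processor("Acme Corp, London")     # -> "Acme Corp, London" (identity)
# Assuming GeoText detects "London", remove_places strips it from the string:
remove_places("Acme Corp, London")             # -> roughly "Acme Corp, "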
6 changes: 3 additions & 3 deletions entity_embed/early_stopping.py
@@ -38,7 +38,7 @@ def __init__(
dirpath=None,
filename=None,
verbose=False,
save_last=None,
save_last=True,
save_top_k=None,
save_weights_only=False,
period=1,
@@ -53,8 +53,8 @@
save_top_k=save_top_k,
save_weights_only=save_weights_only,
mode=mode,
period=period,
prefix=prefix,
# period=period,
# prefix=prefix,
)
self.min_epochs = min_epochs

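The commented-out arguments track newer pytorch-lightning releases, where ModelCheckpoint dropped period and prefix (period's role moved to every_n_epochs, and prefix was folded into the filename template). If a per-epoch save interval is still wanted, a hedged sketch assuming pytorch-lightning >= 1.5:

# Sketch only: "every_n_epochs" takes over the role of the removed "period" kwarg.
# Monitor name and paths are illustrative.
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    monitor="valid_recall_at_0.3",
    dirpath="checkpoints",
    save_last=True,
    save_top_k=1,
    mode="max",
    every_n_epochs=1,
)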
36 changes: 26 additions & 10 deletions entity_embed/entity_embed.py
@@ -12,7 +12,8 @@
from .data_utils.datasets import RecordDataset
from .early_stopping import EarlyStoppingMinEpochs, ModelCheckpointMinEpochs
from .evaluation import f1_score, pair_entity_ratio, precision_and_recall
from .indexes import ANNEntityIndex, ANNLinkageIndex

from .indexes import ANNEntityIndex # , ANNLinkageIndex
from .models import BlockerNet

logger = logging.getLogger(__name__)
@@ -42,11 +43,12 @@ def __init__(
self.record_numericalizer = record_numericalizer
for field_config in self.record_numericalizer.field_config_dict.values():
vocab = field_config.vocab
vector_tensor = field_config.vector_tensor
if vocab:
# We can assume that there's only one vocab type across the
# whole field_config_dict, so we can stop the loop once we've
# found a field_config with a vocab
valid_embedding_size = vocab.vectors.size(1)
valid_embedding_size = vector_tensor.size(1)
if valid_embedding_size != embedding_size:
raise ValueError(
f"Invalid embedding_size={embedding_size}. "
@@ -71,11 +73,11 @@ def __init__(
self.sim_threshold_list = sim_threshold_list
self.index_build_kwargs = index_build_kwargs
self.index_search_kwargs = index_search_kwargs
self._dev = "mps"

def forward(self, tensor_dict, sequence_length_dict, return_field_embeddings=False):
tensor_dict = utils.tensor_dict_to_device(tensor_dict, device=self.device)
sequence_length_dict = utils.tensor_dict_to_device(sequence_length_dict, device=self.device)

return self.blocker_net(tensor_dict, sequence_length_dict, return_field_embeddings)

def _warn_if_empty_indices_tuple(self, indices_tuple, batch_idx):
@@ -99,7 +101,7 @@ def training_step(self, batch, batch_idx):
self.log("train_loss", loss)
return loss

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=None):
self.blocker_net.fix_pool_weights()
self.log_dict(
{
@@ -150,24 +152,27 @@ def fit(
min_epochs=5,
max_epochs=100,
check_val_every_n_epoch=1,
use_early_stop=False,
early_stop_monitor="valid_recall_at_0.3",
early_stop_min_delta=0.0,
early_stop_patience=20,
early_stop_mode=None,
early_stop_verbose=True,
model_save_top_k=1,
model_save_dir=None,
model_save_filename=None,
model_save_verbose=False,
tb_save_dir=None,
tb_name=None,
use_gpu=True,
use_gpu=False,
accelerator="cpu",
ckpt_path=None,
):
if early_stop_mode is None:
if "pair_entity_ratio_at" in early_stop_monitor:
early_stop_mode = "min"
else:
early_stop_mode = "max"

early_stop_callback = EarlyStoppingMinEpochs(
min_epochs=min_epochs,
monitor=early_stop_monitor,
@@ -176,24 +181,31 @@
mode=early_stop_mode,
verbose=early_stop_verbose,
)
callbacks = []
if use_early_stop:
callbacks.append(early_stop_callback)
print("Using early stopping callback...")
checkpoint_callback = ModelCheckpointMinEpochs(
min_epochs=min_epochs,
monitor=early_stop_monitor,
save_top_k=model_save_top_k,
mode=early_stop_mode,
dirpath=model_save_dir,
filename=model_save_filename,
verbose=model_save_verbose,
)
callbacks.append(checkpoint_callback)
trainer_args = {
"min_epochs": min_epochs,
"max_epochs": max_epochs,
"check_val_every_n_epoch": check_val_every_n_epoch,
"callbacks": [early_stop_callback, checkpoint_callback],
"reload_dataloaders_every_epoch": True, # for shuffling ClusterDataset every epoch
"callbacks": callbacks,
"reload_dataloaders_every_n_epochs": 10, # for shuffling ClusterDataset every epoch
}
if use_gpu:
trainer_args["gpus"] = 1

if accelerator:
trainer_args["accelerator"] = accelerator
if tb_name and tb_save_dir:
trainer_args["logger"] = TensorBoardLogger(
tb_save_dir,
@@ -204,8 +216,12 @@
'Please provide both "tb_name" and "tb_save_dir" to enable '
"TensorBoardLogger or omit both to disable it"
)
fit_args = {}
if ckpt_path:
fit_args["ckpt_path"] = ckpt_path
trainer = pl.Trainer(**trainer_args)
trainer.fit(self, datamodule)

trainer.fit(self, datamodule, **fit_args)

logger.info(
"Loading the best validation model from "
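With the updated fit() signature, early stopping is opt-in, training defaults to CPU, the device is selected through accelerator, and ckpt_path allows resuming from a checkpoint. A usage sketch; the class name, datamodule, and argument values are illustrative, not prescribed by this PR:

# Usage sketch for the new fit() arguments; values are illustrative.
model = EntityEmbed(record_numericalizer)  # constructor name assumed from the library's public API
model.fit(
    datamodule,
    min_epochs=5,
    max_epochs=50,
    use_early_stop=True,                  # early stopping is now opt-in (default False)
    early_stop_monitor="valid_recall_at_0.3",
    accelerator="cpu",                    # passed straight through to pl.Trainer
    use_gpu=False,                        # now defaults to False
    ckpt_path=None,                       # set to a checkpoint file to resume training
    model_save_dir="checkpoints",
    tb_save_dir="tb_logs",
    tb_name="preprocessor_run",
)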