diff --git a/entity_embed/__init__.py b/entity_embed/__init__.py index 0ebd1b7..4a5f21d 100644 --- a/entity_embed/__init__.py +++ b/entity_embed/__init__.py @@ -2,11 +2,16 @@ import logging # libgomp issue, must import n2 before torch. See: https://github.com/kakao/n2/issues/42 -import n2 # noqa: F401 +# import n2 # noqa: F401 from .data_modules import * # noqa: F401, F403 from .data_utils.field_config_parser import FieldConfigDictParser # noqa: F401 -from .data_utils.numericalizer import default_tokenizer # noqa: F401 +from .data_utils.numericalizer import ( + default_tokenizer, + remove_space_digit_punc, + remove_places, + default_pre_processor, +) # noqa: F401 from .entity_embed import * # noqa: F401, F403 from .indexes import * # noqa: F401, F403 diff --git a/entity_embed/cli.py b/entity_embed/cli.py index 2aebb02..b2bff4e 100644 --- a/entity_embed/cli.py +++ b/entity_embed/cli.py @@ -369,7 +369,8 @@ def _load_model(kwargs): model = model_cls.load_from_checkpoint(kwargs["model_save_filepath"], datamodule=None) if kwargs["use_gpu"]: - model = model.to(torch.device("cuda")) + # model = model.to(torch.device("cuda")) + model = model.to(torch.device("mps")) else: model = model.to(torch.device("cpu")) return model diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index e548235..b7a6b54 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -2,7 +2,10 @@ import logging from importlib import import_module -from torchtext.vocab import Vocab +import torch +from torch import Tensor, nn +from torchtext.vocab import Vocab, Vectors, FastText +from torchtext.vocab import vocab as factory_vocab from .numericalizer import ( AVAILABLE_VOCABS, @@ -63,9 +66,16 @@ def _parse_field_config(cls, field, field_config, record_list): tokenizer = _import_function( field_config.get("tokenizer", "entity_embed.default_tokenizer") ) + pre_processor = _import_function( + field_config.get("pre_processor", "entity_embed.default_pre_processor") + ) + multi_pre_processor = _import_function( + field_config.get("multi_pre_processor", "entity_embed.default_pre_processor") + ) alphabet = field_config.get("alphabet", DEFAULT_ALPHABET) max_str_len = field_config.get("max_str_len") vocab = None + vector_tensor = None # Check if there's a key defined on the field_config, # useful when we want to have multiple FieldConfig for the same field @@ -92,8 +102,39 @@ def _parse_field_config(cls, field, field_config, record_list): "field_config if you wish to use a override " "an field name." 
) - vocab = Vocab(vocab_counter) - vocab.load_vectors(vocab_type) + + if vocab_type in {"tx_embeddings_large.vec", "tx_embeddings.vec"}: + vectors = Vectors(vocab_type, cache=".vector_cache") + elif vocab_type == "fasttext": + vectors = FastText("en") # might need to add standard fasttext + else: + vocab.load_vectors(vocab_type) # won't work + + # adding token + unk_token = "" + vocab = factory_vocab(vocab_counter, specials=[unk_token]) + # print(vocab[""]) # prints 0 + # make default index same as index of unk_token + vocab.set_default_index(vocab[unk_token]) + # print(vocab["probably out of vocab"]) # prints 0 + + # create vector tensor using tokens in vocab, order important + vectors = [vectors] + # device = torch.device("mps") + tot_dim = sum(v.dim for v in vectors) + # vector_tensor = torch.zeros(len(vocab), tot_dim) + vector_tensor = torch.Tensor(len(vocab), tot_dim) + + for i, token in enumerate(vocab.get_itos()): + start_dim = 0 + for v in vectors: + end_dim = start_dim + v.dim + vector_tensor[i][start_dim:end_dim] = v[token.strip()] + start_dim = end_dim + assert start_dim == tot_dim + + logger.info(f"Vector tensor shape: {vector_tensor.shape}") + assert len(vector_tensor) == len(vocab) # Compute max_str_len if necessary if field_type in (FieldType.STRING, FieldType.MULTITOKEN) and (max_str_len is None): @@ -125,9 +166,12 @@ def _parse_field_config(cls, field, field_config, record_list): key=key, field_type=field_type, tokenizer=tokenizer, + pre_processor=pre_processor, + multi_pre_processor=multi_pre_processor, alphabet=alphabet, max_str_len=max_str_len, vocab=vocab, + vector_tensor=vector_tensor, n_channels=n_channels, embed_dropout_p=embed_dropout_p, use_attention=use_attention, diff --git a/entity_embed/data_utils/numericalizer.py b/entity_embed/data_utils/numericalizer.py index dbdbdc8..14246d8 100644 --- a/entity_embed/data_utils/numericalizer.py +++ b/entity_embed/data_utils/numericalizer.py @@ -4,10 +4,12 @@ from enum import Enum from typing import Callable, List +from string import punctuation import numpy as np import regex import torch from torchtext.vocab import Vocab +from flashgeotext.geotext import GeoText, GeoTextConfiguration logger = logging.getLogger(__name__) @@ -27,6 +29,7 @@ "glove.6B.100d", "glove.6B.200d", "glove.6B.300d", + "tx_embeddings_large.vec", ] @@ -41,10 +44,13 @@ class FieldType(Enum): class FieldConfig: key: str field_type: FieldType + pre_processor: Callable[[str], List[str]] + multi_pre_processor: Callable[[str], List[str]] tokenizer: Callable[[str], List[str]] alphabet: List[str] max_str_len: int vocab: Vocab + vector_tensor: torch.Tensor n_channels: int embed_dropout_p: float use_attention: bool @@ -60,7 +66,7 @@ def __repr__(self): repr_dict = {} for k, v in self.__dict__.items(): if isinstance(v, Callable): - repr_dict[k] = f"{inspect.getmodule(v).__name__}.{v.__name__}" + repr_dict[k] = f"{inspect.getmodule(v).__name__}.{getattr(v, '.__name__', repr(v))}" else: repr_dict[k] = v return "{cls}({attrs})".format( @@ -77,6 +83,31 @@ def default_tokenizer(val): return tokenizer_re.findall(val) +def remove_space_digit_punc(val): + val = "".join(c for c in val if (not c.isdigit()) and (c not in punctuation)) + return val.replace(" ", "") + + +config = GeoTextConfiguration(**{"case_sensitive": False}) +geotext = GeoText(config) + + +def default_pre_processor(text): + return text + + +def remove_places(text): + places = geotext.extract(text) + found_places = [] + for i, v in places.items(): + for w, x in v.items(): + word = x["found_as"][0] + if 
word not in ["at", "com", "us", "usa"]: + found_places.append(word) + text = text.replace(word, "") + return text + + class StringNumericalizer: is_multitoken = False @@ -85,6 +116,8 @@ def __init__(self, field, field_config): self.alphabet = field_config.alphabet self.max_str_len = field_config.max_str_len self.char_to_ord = {c: i for i, c in enumerate(self.alphabet)} + self.pre_processor = field_config.pre_processor + print(f"Found pre_processor {self.pre_processor} for field {self.field}") def _ord_encode(self, val): ord_encoded = [] @@ -100,10 +133,15 @@ def build_tensor(self, val): # encoded_arr is a one hot encoded bidimensional tensor # with characters as rows and positions as columns. # This is the shape expected by StringEmbedCNN. + # if val != self.pre_processor(val): + # print(f"{val} -> {self.pre_processor(val)} -> {self.pre_processor} -> {self.field}") + val = self.pre_processor(val) ord_encoded_val = self._ord_encode(val) + ord_encoded_val = ord_encoded_val[: self.max_str_len] # truncate to max_str_len encoded_arr = np.zeros((len(self.alphabet), self.max_str_len), dtype=np.float32) if len(ord_encoded_val) > 0: encoded_arr[ord_encoded_val, range(len(ord_encoded_val))] = 1.0 + t = torch.from_numpy(encoded_arr) return t, len(val) @@ -127,10 +165,16 @@ class MultitokenNumericalizer: def __init__(self, field, field_config): self.field = field + self.field_type = field_config.field_type + self.multi_pre_processor = field_config.multi_pre_processor self.tokenizer = field_config.tokenizer self.string_numericalizer = StringNumericalizer(field=field, field_config=field_config) + print(f"Found multi_pre_processor {self.multi_pre_processor} for field {self.field}") def build_tensor(self, val): + # if val != self.multi_pre_processor(val): + # print(f"{val} -> {self.multi_pre_processor(val)} -> {self.multi_pre_processor} -> {self.field}") + val = self.multi_pre_processor(val) val_tokens = self.tokenizer(val) t_list = [] for v in val_tokens: @@ -149,6 +193,7 @@ class SemanticMultitokenNumericalizer(MultitokenNumericalizer): def __init__(self, field, field_config): self.field = field self.tokenizer = field_config.tokenizer + self.multi_pre_processor = field_config.multi_pre_processor self.string_numericalizer = SemanticStringNumericalizer( field=field, field_config=field_config ) diff --git a/entity_embed/early_stopping.py b/entity_embed/early_stopping.py index cf80e49..5edd788 100644 --- a/entity_embed/early_stopping.py +++ b/entity_embed/early_stopping.py @@ -38,7 +38,7 @@ def __init__( dirpath=None, filename=None, verbose=False, - save_last=None, + save_last=True, save_top_k=None, save_weights_only=False, period=1, @@ -53,8 +53,8 @@ def __init__( save_top_k=save_top_k, save_weights_only=save_weights_only, mode=mode, - period=period, - prefix=prefix, + # period=period, + # prefix=prefix, ) self.min_epochs = min_epochs diff --git a/entity_embed/entity_embed.py b/entity_embed/entity_embed.py index c4b8d95..c613103 100644 --- a/entity_embed/entity_embed.py +++ b/entity_embed/entity_embed.py @@ -12,7 +12,8 @@ from .data_utils.datasets import RecordDataset from .early_stopping import EarlyStoppingMinEpochs, ModelCheckpointMinEpochs from .evaluation import f1_score, pair_entity_ratio, precision_and_recall -from .indexes import ANNEntityIndex, ANNLinkageIndex + +from .indexes import ANNEntityIndex # , ANNLinkageIndex from .models import BlockerNet logger = logging.getLogger(__name__) @@ -42,11 +43,12 @@ def __init__( self.record_numericalizer = record_numericalizer for field_config in 
self.record_numericalizer.field_config_dict.values(): vocab = field_config.vocab + vector_tensor = field_config.vector_tensor if vocab: # We can assume that there's only one vocab type across the # whole field_config_dict, so we can stop the loop once we've # found a field_config with a vocab - valid_embedding_size = vocab.vectors.size(1) + valid_embedding_size = vector_tensor.size(1) if valid_embedding_size != embedding_size: raise ValueError( f"Invalid embedding_size={embedding_size}. " @@ -71,11 +73,11 @@ def __init__( self.sim_threshold_list = sim_threshold_list self.index_build_kwargs = index_build_kwargs self.index_search_kwargs = index_search_kwargs + self._dev = "mps" def forward(self, tensor_dict, sequence_length_dict, return_field_embeddings=False): tensor_dict = utils.tensor_dict_to_device(tensor_dict, device=self.device) sequence_length_dict = utils.tensor_dict_to_device(sequence_length_dict, device=self.device) - return self.blocker_net(tensor_dict, sequence_length_dict, return_field_embeddings) def _warn_if_empty_indices_tuple(self, indices_tuple, batch_idx): @@ -99,7 +101,7 @@ def training_step(self, batch, batch_idx): self.log("train_loss", loss) return loss - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=None): self.blocker_net.fix_pool_weights() self.log_dict( { @@ -150,6 +152,7 @@ def fit( min_epochs=5, max_epochs=100, check_val_every_n_epoch=1, + use_early_stop=False, early_stop_monitor="valid_recall_at_0.3", early_stop_min_delta=0.0, early_stop_patience=20, @@ -157,17 +160,19 @@ def fit( early_stop_verbose=True, model_save_top_k=1, model_save_dir=None, + model_save_filename=None, model_save_verbose=False, tb_save_dir=None, tb_name=None, - use_gpu=True, + use_gpu=False, + accelerator="cpu", + ckpt_path=None, ): if early_stop_mode is None: if "pair_entity_ratio_at" in early_stop_monitor: early_stop_mode = "min" else: early_stop_mode = "max" - early_stop_callback = EarlyStoppingMinEpochs( min_epochs=min_epochs, monitor=early_stop_monitor, @@ -176,24 +181,31 @@ def fit( mode=early_stop_mode, verbose=early_stop_verbose, ) + callbacks = [] + if use_early_stop: + callbacks.append(early_stop_callback) + print("Using early stopping callback...") checkpoint_callback = ModelCheckpointMinEpochs( min_epochs=min_epochs, monitor=early_stop_monitor, save_top_k=model_save_top_k, mode=early_stop_mode, dirpath=model_save_dir, + filename=model_save_filename, verbose=model_save_verbose, ) + callbacks.append(checkpoint_callback) trainer_args = { "min_epochs": min_epochs, "max_epochs": max_epochs, "check_val_every_n_epoch": check_val_every_n_epoch, - "callbacks": [early_stop_callback, checkpoint_callback], - "reload_dataloaders_every_epoch": True, # for shuffling ClusterDataset every epoch + "callbacks": callbacks, + "reload_dataloaders_every_n_epochs": 10, # for shuffling ClusterDataset every epoch } if use_gpu: trainer_args["gpus"] = 1 - + if accelerator: + trainer_args["accelerator"] = accelerator if tb_name and tb_save_dir: trainer_args["logger"] = TensorBoardLogger( tb_save_dir, @@ -204,8 +216,12 @@ def fit( 'Please provide both "tb_name" and "tb_save_dir" to enable ' "TensorBoardLogger or omit both to disable it" ) + fit_args = {} + if ckpt_path: + fit_args["ckpt_path"] = ckpt_path trainer = pl.Trainer(**trainer_args) - trainer.fit(self, datamodule) + + trainer.fit(self, datamodule, **fit_args) logger.info( "Loading the best validation model from " diff --git a/entity_embed/evaluation.py 
b/entity_embed/evaluation.py index 87ed12d..4cf767d 100755 --- a/entity_embed/evaluation.py +++ b/entity_embed/evaluation.py @@ -1,5 +1,12 @@ import csv import json +import random +from .indexes import ANNEntityIndex +from .data_utils import utils +import pandas as pd +import logging + +logger = logging.getLogger(__name__) def pair_entity_ratio(found_pair_set_len, entity_count): @@ -53,3 +60,59 @@ def evaluate_output_json( f1_score(precision, recall), pair_entity_ratio(len(found_pair_set), record_count), ) + + +class EmbeddingEvaluator: + def __init__(self, record_dict, vector_dict, cluster_field="cluster_id"): + self.record_dict = record_dict + self.cluster_field = cluster_field + embedding_size = len(next(iter(vector_dict.values()))) + logging.info("Building index...") + self.ann_index = ANNEntityIndex(embedding_size) + self.ann_index.insert_vector_dict(vector_dict) + self.ann_index.build() + logging.info("Index built!") + self.cluster_dict = utils.record_dict_to_cluster_dict(self.record_dict, self.cluster_field) + self.pos_pair_set = utils.cluster_dict_to_id_pairs(self.cluster_dict) + + def evaluate(self, k, sim_thresholds, query_ids=None, get_missing_pair_set=False): + """ + params: + k: int: number of nearest neighbours to retrieve + sim_thresholds: list of floats in the range [0,1]: + query_ids: list or set of ids that must be keys in self.vector_dict and self.record_dict. Indicates + which ids to find pairs for. If None, use all record ids as query ids + + returns: pandas DataFrame of results, with one row for each threshold + """ + if query_ids is None: + logging.info(f"Using all {len(self.record_dict)} records to query for neighbours") + pos_pair_subset = self.pos_pair_set + else: + query_ids = set(query_ids) + logging.info(f"Using subset of {len(query_ids)} query IDs") + pos_pair_subset = { + pair for pair in self.pos_pair_set if pair[0] in query_ids or pair[1] in query_ids + } + results = [] + for sim_threshold in sim_thresholds: + found_pair_set = self.ann_index.search_pairs( + k, sim_threshold, query_id_subset=query_ids + ) + precision, recall = precision_and_recall(found_pair_set, pos_pair_subset) + results.append( + (sim_threshold, precision, recall, f1_score(precision, recall), len(found_pair_set)) + ) + if get_missing_pair_set & (sim_threshold == min(sim_thresholds)): + self.missing_pair_set = pos_pair_subset - found_pair_set + id_to_name_map = {k: v["merchant_name"] for k, v in self.record_dict.items()} + self.missing_pair_name_set = set( + map( + lambda x: (id_to_name_map[x[0]], id_to_name_map[x[1]]), + self.missing_pair_set, + ) + ) + + return pd.DataFrame( + results, columns=["threshold", "precision", "recall", "f1_score", "n_pairs_found"] + ) diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index c59151a..9363152 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -1,6 +1,8 @@ import logging +import faiss +import numpy as np -from n2 import HnswIndex +# from n2 import HnswIndex from .helpers import build_index_build_kwargs, build_index_search_kwargs @@ -9,13 +11,20 @@ class ANNEntityIndex: def __init__(self, embedding_size): - self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") + self.approx_knn_index = faiss.index_factory( + embedding_size, "Flat", faiss.METRIC_INNER_PRODUCT + ) + # self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") self.vector_idx_to_id = None + self.normalized_vector_array = None self.is_built = False + self.embedding_size = embedding_size def 
insert_vector_dict(self, vector_dict): - for vector in vector_dict.values(): - self.approx_knn_index.add_data(vector) + vector_array = np.array(list(vector_dict.values())) + l2_norm = np.linalg.norm(vector_array, ord=2, axis=1).reshape(vector_array.shape[0], 1) + self.normalized_vector_array = vector_array / l2_norm + self.approx_knn_index.add(self.normalized_vector_array) self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) def build( @@ -25,116 +34,173 @@ def build( if self.vector_idx_to_id is None: raise ValueError("Please call insert_vector_dict first") - actual_index_build_kwargs = build_index_build_kwargs(index_build_kwargs) - self.approx_knn_index.build(**actual_index_build_kwargs) + # actual_index_build_kwargs = build_index_build_kwargs(index_build_kwargs) + # self.approx_knn_index.build(**actual_index_build_kwargs) self.is_built = True + # faiss.write_index(self.approx_knn_index, "vector.index") - def search_pairs(self, k, sim_threshold, index_search_kwargs=None): + def search_pairs(self, k, sim_threshold, index_search_kwargs=None, query_id_subset=None): if not self.is_built: raise ValueError("Please call build first") if sim_threshold > 1 or sim_threshold < 0: raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") logger.debug("Searching on approx_knn_index...") - distance_threshold = 1 - sim_threshold index_search_kwargs = build_index_search_kwargs(index_search_kwargs) - neighbor_and_distance_list_of_list = self.approx_knn_index.batch_search_by_ids( - item_ids=self.vector_idx_to_id.keys(), - k=k, - include_distances=True, - **index_search_kwargs, - ) - - logger.debug("Search on approx_knn_index done, building found_pair_set now...") found_pair_set = set() - for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): - left_id = self.vector_idx_to_id[i] - for j, distance in neighbor_distance_list: - if i != j and distance <= distance_threshold: - right_id = self.vector_idx_to_id[j] - # must use sorted to always have smaller id on left of pair tuple - pair = tuple(sorted([left_id, right_id])) - found_pair_set.add(pair) + for i, left_id in self.vector_idx_to_id.items(): + if query_id_subset is None or left_id in query_id_subset: + vector = self.normalized_vector_array[[i], :] + similarities, neighbours = self.approx_knn_index.search(vector, k=k) + if all(similarities[0] >= sim_threshold) & (sim_threshold > 0.4): + print( + f"Found pair similarities for k = {k} are all higher than threshold {sim_threshold}" + ) + for similarity, j in zip(similarities[0], neighbours[0]): + if i != j and similarity >= sim_threshold: + right_id = self.vector_idx_to_id[j] + # must use sorted to always have smaller id on left of pair tuple + pair = tuple(sorted([left_id, right_id])) + found_pair_set.add(pair) logger.debug( - f"Building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." + f"Search on approx_knn_index and building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." 
) return found_pair_set -class ANNLinkageIndex: - def __init__(self, embedding_size): - self.left_index = ANNEntityIndex(embedding_size) - self.right_index = ANNEntityIndex(embedding_size) - - def insert_vector_dict(self, left_vector_dict, right_vector_dict): - self.left_index.insert_vector_dict(vector_dict=left_vector_dict) - self.right_index.insert_vector_dict(vector_dict=right_vector_dict) - - def build( - self, - index_build_kwargs=None, - ): - self.left_index.build(index_build_kwargs=index_build_kwargs) - self.right_index.build(index_build_kwargs=index_build_kwargs) - - def search_pairs( - self, - k, - sim_threshold, - left_vector_dict, - right_vector_dict, - left_source, - index_search_kwargs=None, - ): - if not self.left_index.is_built or not self.right_index.is_built: - raise ValueError("Please call build first") - if sim_threshold > 1 or sim_threshold < 0: - raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") - - index_search_kwargs = build_index_search_kwargs(index_search_kwargs) - distance_threshold = 1 - sim_threshold - all_pair_set = set() - - for dataset_name, index, vector_dict, other_index in [ - (left_source, self.left_index, right_vector_dict, self.right_index), - (None, self.right_index, left_vector_dict, self.left_index), - ]: - logger.debug(f"Searching on approx_knn_index of dataset_name={dataset_name}...") - - neighbor_and_distance_list_of_list = index.approx_knn_index.batch_search_by_vectors( - vs=vector_dict.values(), k=k, include_distances=True, **index_search_kwargs - ) - - logger.debug( - f"Search on approx_knn_index of dataset_name={dataset_name}... done, " - "filling all_pair_set now..." - ) - - for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): - other_id = other_index.vector_idx_to_id[i] - for j, distance in neighbor_distance_list: - if distance <= distance_threshold: # do NOT check for i != j here - id_ = index.vector_idx_to_id[j] - if dataset_name and dataset_name == left_source: - left_id, right_id = (id_, other_id) - else: - left_id, right_id = (other_id, id_) - pair = ( - left_id, - right_id, - ) # do NOT use sorted here, figure out from datasets - all_pair_set.add(pair) - - logger.debug(f"Filling all_pair_set with dataset_name={dataset_name} done.") - - logger.debug( - "All searches done, all_pair_set filled. " - f"Found len(all_pair_set)={len(all_pair_set)} pairs." 
- ) - - return all_pair_set +# class ANNEntityIndex: +# def __init__(self, embedding_size): +# self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") +# self.vector_idx_to_id = None +# self.is_built = False + +# def insert_vector_dict(self, vector_dict): +# for vector in vector_dict.values(): +# self.approx_knn_index.add_data(vector) +# self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) + +# def build( +# self, +# index_build_kwargs=None, +# ): +# if self.vector_idx_to_id is None: +# raise ValueError("Please call insert_vector_dict first") + +# actual_index_build_kwargs = build_index_build_kwargs(index_build_kwargs) +# self.approx_knn_index.build(**actual_index_build_kwargs) +# self.is_built = True + +# def search_pairs(self, k, sim_threshold, index_search_kwargs=None): +# if not self.is_built: +# raise ValueError("Please call build first") +# if sim_threshold > 1 or sim_threshold < 0: +# raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") + +# logger.debug("Searching on approx_knn_index...") + +# distance_threshold = 1 - sim_threshold + +# index_search_kwargs = build_index_search_kwargs(index_search_kwargs) +# neighbor_and_distance_list_of_list = self.approx_knn_index.batch_search_by_ids( +# item_ids=self.vector_idx_to_id.keys(), +# k=k, +# include_distances=True, +# **index_search_kwargs, +# ) + +# logger.debug("Search on approx_knn_index done, building found_pair_set now...") + +# found_pair_set = set() +# for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): +# left_id = self.vector_idx_to_id[i] +# for j, distance in neighbor_distance_list: +# if i != j and distance <= distance_threshold: +# right_id = self.vector_idx_to_id[j] +# # must use sorted to always have smaller id on left of pair tuple +# pair = tuple(sorted([left_id, right_id])) +# found_pair_set.add(pair) + +# logger.debug( +# f"Building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." 
+# ) + +# return found_pair_set + + +# class ANNLinkageIndex: +# def __init__(self, embedding_size): +# self.left_index = ANNEntityIndex(embedding_size) +# self.right_index = ANNEntityIndex(embedding_size) + +# def insert_vector_dict(self, left_vector_dict, right_vector_dict): +# self.left_index.insert_vector_dict(vector_dict=left_vector_dict) +# self.right_index.insert_vector_dict(vector_dict=right_vector_dict) + +# def build( +# self, +# index_build_kwargs=None, +# ): +# self.left_index.build(index_build_kwargs=index_build_kwargs) +# self.right_index.build(index_build_kwargs=index_build_kwargs) + +# def search_pairs( +# self, +# k, +# sim_threshold, +# left_vector_dict, +# right_vector_dict, +# left_source, +# index_search_kwargs=None, +# ): +# if not self.left_index.is_built or not self.right_index.is_built: +# raise ValueError("Please call build first") +# if sim_threshold > 1 or sim_threshold < 0: +# raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") + +# index_search_kwargs = build_index_search_kwargs(index_search_kwargs) +# distance_threshold = 1 - sim_threshold +# all_pair_set = set() + +# for dataset_name, index, vector_dict, other_index in [ +# (left_source, self.left_index, right_vector_dict, self.right_index), +# (None, self.right_index, left_vector_dict, self.left_index), +# ]: +# logger.debug(f"Searching on approx_knn_index of dataset_name={dataset_name}...") + +# neighbor_and_distance_list_of_list = index.approx_knn_index.batch_search_by_vectors( +# vs=vector_dict.values(), k=k, include_distances=True, **index_search_kwargs +# ) + +# logger.debug( +# f"Search on approx_knn_index of dataset_name={dataset_name}... done, " +# "filling all_pair_set now..." +# ) + +# for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): +# other_id = other_index.vector_idx_to_id[i] +# for j, distance in neighbor_distance_list: +# if distance <= distance_threshold: # do NOT check for i != j here +# id_ = index.vector_idx_to_id[j] +# if dataset_name and dataset_name == left_source: +# left_id, right_id = (id_, other_id) +# else: +# left_id, right_id = (other_id, id_) +# pair = ( +# left_id, +# right_id, +# ) # do NOT use sorted here, figure out from datasets +# all_pair_set.add(pair) + +# logger.debug(f"Filling all_pair_set with dataset_name={dataset_name} done.") + +# logger.debug( +# "All searches done, all_pair_set filled. " +# f"Found len(all_pair_set)={len(all_pair_set)} pairs." +# ) + +# return all_pair_set diff --git a/entity_embed/models.py b/entity_embed/models.py index 68f8159..e69b5cb 100644 --- a/entity_embed/models.py +++ b/entity_embed/models.py @@ -58,7 +58,7 @@ def __init__(self, field_config, embedding_size): self.embedding_size = embedding_size self.dense_net = nn.Sequential( - nn.Embedding.from_pretrained(field_config.vocab.vectors), + nn.Embedding.from_pretrained(field_config.vector_tensor), nn.Dropout(p=field_config.embed_dropout_p), ) @@ -151,6 +151,7 @@ def forward(self, x, sequence_lengths, **kwargs): x_list = x.unbind(dim=1) x_list = [self.embed_net(x) for x in x_list] x = torch.stack(x_list, dim=1) + print(x.device) # Compute a mask for the attention on the padded sequences # See e.g. 
https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5 diff --git a/hannah_requirements.txt b/hannah_requirements.txt new file mode 100644 index 0000000..8437c66 --- /dev/null +++ b/hannah_requirements.txt @@ -0,0 +1,15 @@ +click==8.0.4 +faiss==1.5.3 +mock==4.0.3 +more_itertools==9.0.0 +numpy==1.21.5 +ordered_set==4.1.0 +pytest==7.1.1 +pytorch_lightning==1.7.7 +pytorch_metric_learning==1.6.2 +regex==2022.3.15 +setuptools==61.2.0 +sphinx_rtd_theme==1.0.0 +torch==1.13.0 +torchtext==0.14.0 +tqdm==4.64.0 diff --git a/requirements.txt b/requirements.txt index 199a67b..4384bba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,14 @@ click==7.1.2,<8.0 more-itertools>=8.6.0,<9.0 -n2>=0.1.7,<1.2 numpy>=1.19.0 ordered-set>=4.0.2 -pytorch_lightning>=1.1.6,<1.3 -pytorch-metric-learning>=0.9.98,<1.0 regex>=2020.11.13 -torch>=1.7.1,<1.9 -torchtext>=0.8,<0.10 -torchvision>=0.8.2<0.10 +pytorch-lightning==1.7.7 +pytorch-metric-learning>=0.9.99 +torch==1.12.1 +torchmetrics>=0.10.1 +torchtext==0.13.1 tqdm>=4.53.0 + +# conda install grpcio +# pip install faiss-cpu diff --git a/tests/test_data_utils_helpers.py b/tests/test_data_utils_helpers.py index 830717f..92ffe85 100644 --- a/tests/test_data_utils_helpers.py +++ b/tests/test_data_utils_helpers.py @@ -3,7 +3,8 @@ import tempfile import mock -import n2 # noqa: F401 + +# import n2 # noqa: F401 import pytest from entity_embed.data_utils.field_config_parser import FieldConfigDictParser from entity_embed.data_utils.numericalizer import FieldConfig, FieldType, RecordNumericalizer
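Note on the ANN change above: the patch replaces the n2 HnswIndex ("angular" metric, thresholded as distance <= 1 - sim_threshold) with an exact FAISS flat index using METRIC_INNER_PRODUCT, L2-normalising every vector first so that the inner product equals cosine similarity and the check becomes similarity >= sim_threshold directly; because a flat index is exact, build() no longer consumes index_build_kwargs. A minimal, self-contained sketch of that pattern follows — it assumes only faiss-cpu and numpy, and the toy ids and vectors are invented for illustration, not taken from the project:

import faiss
import numpy as np

embedding_size = 4
vector_dict = {                      # hypothetical id -> embedding mapping
    "a": np.array([1.0, 0.0, 0.0, 0.0]),
    "b": np.array([0.9, 0.1, 0.0, 0.0]),
    "c": np.array([0.0, 0.0, 1.0, 0.0]),
}

# L2-normalise so that inner product == cosine similarity
vectors = np.array(list(vector_dict.values()), dtype=np.float32)
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)

index = faiss.index_factory(embedding_size, "Flat", faiss.METRIC_INNER_PRODUCT)
index.add(vectors)                   # a Flat index needs no training/build step

idx_to_id = dict(enumerate(vector_dict))
k, sim_threshold = 2, 0.8
found_pair_set = set()
similarities, neighbours = index.search(vectors, k)
for i, (sims, nbrs) in enumerate(zip(similarities, neighbours)):
    for sim, j in zip(sims, nbrs):
        if i != j and sim >= sim_threshold:
            # sorted() keeps the smaller id on the left, as in search_pairs
            found_pair_set.add(tuple(sorted((idx_to_id[i], idx_to_id[j]))))

print(found_pair_set)                # {('a', 'b')} for this toy data

Note on the vocab/embedding change: field_config_parser.py now builds the vocabulary with the torchtext >= 0.13 factory (vocab(counter, specials=...) plus set_default_index) and materialises a vector_tensor that models.py feeds to nn.Embedding.from_pretrained, replacing the removed Vocab(counter).load_vectors(...) API. The sketch below shows that flow with placeholders only: the token counts are invented, "<unk>" stands in for the patch's "" unknown token, and random rows stand in for the torchtext.vocab.Vectors lookups (e.g. tx_embeddings.vec) used in the real code:

from collections import Counter
import torch
from torch import nn
from torchtext.vocab import vocab as factory_vocab

vocab_counter = Counter(["coffee", "shop", "coffee"])   # hypothetical token counts
unk_token = "<unk>"                                     # the patch uses "" here
vocab = factory_vocab(vocab_counter, specials=[unk_token])
vocab.set_default_index(vocab[unk_token])               # out-of-vocab tokens map to unk

# One row per vocab entry, in get_itos() order; random rows stand in for
# the pretrained vectors the patch copies out of torchtext.vocab.Vectors.
embedding_dim = 8
vector_tensor = torch.randn(len(vocab), embedding_dim)
assert len(vector_tensor) == len(vocab)

embed = nn.Embedding.from_pretrained(vector_tensor)
out = embed(torch.tensor([vocab["coffee"], vocab["never seen before"]]))
print(out.shape)                                        # torch.Size([2, 8])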