From a2cf02b72bb4c382902a861d4366c5f075128f9c Mon Sep 17 00:00:00 2001 From: ashahzada Date: Mon, 24 Oct 2022 12:35:37 +0100 Subject: [PATCH 01/17] notebooks: added merchant deduplication notebook --- notebooks/Deduplication-merchants.ipynb | 2320 +++++++++++++++++++++++ 1 file changed, 2320 insertions(+) create mode 100755 notebooks/Deduplication-merchants.ipynb diff --git a/notebooks/Deduplication-merchants.ipynb b/notebooks/Deduplication-merchants.ipynb new file mode 100755 index 0000000..996a043 --- /dev/null +++ b/notebooks/Deduplication-merchants.ipynb @@ -0,0 +1,2320 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Deduplication Example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Boilerplate" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from importlib import reload\n", + "import logging\n", + "import torch\n", + "import json\n", + "reload(logging)\n", + "logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, '..')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import entity_embed_local\n", + "from entity_embed_local import EntityEmbedLocal\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "torch.set_num_threads(1)\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "random_seed = 40\n", + "torch.manual_seed(random_seed)\n", + "np.random.seed(random_seed)" + ] + }, + { + "cell_type": "markdown", 
+ "metadata": {}, + "source": [ + "## Load Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's load the merchant alias dataset from a local CSV file (`data/alias_data.csv`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv('data/alias_data.csv',names=['merchant_name','merchant_id','plaid_merchant','plaid_category','avg_tx_amount','count'])\n", + "dataset = data.groupby('merchant_id').head(200)\n", + "dataset.avg_tx_amount = dataset.avg_tx_amount.round()\n", + "dataset = dataset.astype(str)\n", + "# dataset = dataset[dataset.merchant_id.isin(dataset.merchant_id.unique()[:50000])]\n", + "dataset['cluster_id'] = pd.factorize(dataset['merchant_id'].tolist())[0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "mapping = pd.Series(dataset.merchant_name.values,index=dataset.cluster_id).to_dict()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "record_dict = {}\n", + "cluster_field = 'cluster_id'\n", + "\n", + "for current_record_id, record in dataset[['cluster_id','merchant_name','plaid_merchant','plaid_category','avg_tx_amount']].iterrows():\n", + " record['id'] = current_record_id\n", + " record[cluster_field] = int(record[cluster_field]) # convert cluster_field to int\n", + " record_dict[current_record_id] = record.to_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "How many clusters does this dataset have?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "35104" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cluster_total = len(set(record[cluster_field] for record in record_dict.values()))\n", + "cluster_total" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "From all clusters, we'll use 60% for training and another 15% for validation (the remaining 25% is held out for testing) to test how well we can generalize:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "12:55:26 INFO:Singleton cluster sizes (train, valid, test):(18649, 4662, 7772)\n", + "12:55:26 INFO:Plural cluster sizes (train, valid, test):(2412, 603, 1006)\n" + ] + } + ], + "source": [ + "from data_utils import utils\n", + "\n", + "train_record_dict, valid_record_dict, test_record_dict = utils.split_record_dict_on_clusters(\n", + " record_dict=record_dict,\n", + " cluster_field=cluster_field,\n", + " train_proportion=0.60,\n", + " valid_proportion=0.15,\n", + " random_seed=random_seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note we're splitting the data on **clusters**, not records, so the record counts vary:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(26719, 6591, 11098)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_record_dict), len(valid_record_dict), len(test_record_dict)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preprocess" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll perform a very minimal preprocessing of the dataset. 
We want to simply force ASCII chars, lowercase all chars, and strip leading and trailing whitespace.\n", + "\n", + "The fields we'll clean are the ones we'll use:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "field_list = ['merchant_name','plaid_merchant','plaid_category','avg_tx_amount']" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import unidecode\n", + "\n", + "def clean_str(s):\n", + " return unidecode.unidecode(s).lower().strip()\n", + "\n", + "for record in record_dict.values():\n", + " for field in field_list:\n", + " record[field] = clean_str(record[field])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'merchant_name': 'dutchbrosll',\n", + " 'plaid_merchant': 'dutch bros. coffee',\n", + " 'plaid_category': 'food and drink restaurants coffee shop',\n", + " 'avg_tx_amount': '-10.0'}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "utils.subdict(record_dict[2], field_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Forcing ASCII chars in this dataset is useful to improve recall because there's little difference between accented and not-accented chars here. Also, this dataset contains mostly latin chars." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Entity Embed fields" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we will define how record fields will be numericalized and encoded by the neural network. 
First we set an `alphabet`, here we'll use ASCII numbers, letters, symbols and space:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'0123456789abcdefghijklmnopqrstuvwxyz!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~ '" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from data_utils.field_config_parser import DEFAULT_ALPHABET\n", + "\n", + "alphabet = DEFAULT_ALPHABET\n", + "''.join(alphabet)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It's worth noting you can use any alphabet you need, so the accent removal we performed is optional." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we set an `field_config_dict`. It defines `field_type`s that determine how fields are processed in the neural network:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "field_config_dict = {\n", + " \n", + " 'merchant_name': {\n", + " 'field_type': \"MULTITOKEN\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'alphabet': alphabet,\n", + " 'max_str_len': None, # compute\n", + " },\n", + "\n", + " 'merchant_name_semantic': {\n", + " 'key': 'merchant_name',\n", + " 'field_type': \"SEMANTIC_MULTITOKEN\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'vocab': \"tx_embeddings_large.vec\",\n", + " 'max_str_len': None, # compute\n", + " },\n", + " 'avg_tx_amount': {\n", + " 'field_type': \"STRING\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'alphabet': alphabet,\n", + " 'max_str_len': None, # compute\n", + " },\n", + " 'avg_tx_amount_semantic': {\n", + " 'key':\"avg_tx_amount\",\n", + " 'field_type': \"SEMANTIC_STRING\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'vocab': \"tx_embeddings_large.vec\",\n", + " 'max_str_len': None, # compute\n", + " 
},\n", + " 'plaid_merchant': {\n", + " 'field_type': \"MULTITOKEN\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'alphabet': alphabet,\n", + " 'max_str_len': None, # compute\n", + " },\n", + " 'plaid_merchant_semantic': {\n", + " 'key': 'plaid_merchant',\n", + " 'field_type': \"SEMANTIC_MULTITOKEN\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'vocab': \"tx_embeddings_large.vec\",\n", + " },\n", + " 'plaid_category': {\n", + " 'field_type': \"MULTITOKEN\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'alphabet': alphabet,\n", + " 'max_str_len': None, # compute\n", + " },\n", + " 'plaid_category_semantic': {\n", + " 'key': 'plaid_category',\n", + " 'field_type': \"SEMANTIC_MULTITOKEN\",\n", + " 'tokenizer': \"entity_embed.default_tokenizer\",\n", + " 'vocab': \"tx_embeddings_large.vec\",\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we use our `field_config_dict` to get a `record_numericalizer`. This object will convert the strings from our records into tensors for the neural network.\n", + "\n", + "The same `record_numericalizer` must be used on ALL data: train, valid, test. This ensures numericalization will be consistent. Therefore, we pass `record_list=record_dict.values()`:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "12:55:27 INFO:For field=merchant_name, computing actual max_str_len\n", + "12:55:27 INFO:actual_max_str_len=23 must be even to enable NN pooling. 
Updating to 24\n", + "12:55:27 INFO:For field=merchant_name, using actual_max_str_len=24\n", + "12:55:27 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n", + "12:55:29 INFO:For field=avg_tx_amount, computing actual max_str_len\n", + "12:55:29 INFO:For field=avg_tx_amount, using actual_max_str_len=8\n", + "12:55:29 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n", + "12:55:30 INFO:For field=plaid_merchant, computing actual max_str_len\n", + "12:55:30 INFO:actual_max_str_len=23 must be even to enable NN pooling. Updating to 24\n", + "12:55:30 INFO:For field=plaid_merchant, using actual_max_str_len=24\n", + "12:55:31 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n", + "12:55:32 INFO:For field=plaid_category, computing actual max_str_len\n", + "12:55:32 INFO:actual_max_str_len=17 must be even to enable NN pooling. 
Updating to 18\n", + "12:55:32 INFO:For field=plaid_category, using actual_max_str_len=18\n", + "12:55:32 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n" + ] + } + ], + "source": [ + "from data_utils.field_config_parser import FieldConfigDictParser\n", + "\n", + "record_numericalizer = FieldConfigDictParser.from_dict(field_config_dict, record_list=record_dict.values())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize Data Module" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "under the hood, Entity Embed uses [pytorch-lightning](https://pytorch-lightning.readthedocs.io/en/latest/), so we need to create a datamodule object:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from entity_embed import DeduplicationDataModule\n", + "\n", + "batch_size = 32\n", + "eval_batch_size = 64\n", + "datamodule = DeduplicationDataModule(\n", + " train_record_dict=train_record_dict,\n", + " valid_record_dict=valid_record_dict,\n", + " test_record_dict=test_record_dict,\n", + " cluster_field=cluster_field,\n", + " record_numericalizer=record_numericalizer,\n", + " batch_size=batch_size,\n", + " eval_batch_size=eval_batch_size,\n", + " random_seed=random_seed,\n", + " train_loader_kwargs ={\"num_workers\":0,\"multiprocessing_context\":\"fork\"},\n", + " eval_loader_kwargs ={\"num_workers\":0,\"multiprocessing_context\":\"fork\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've used `DeduplicationDataModule` because we're doing Deduplication of a single dataset/table (a.k.a. Entity Clustering, Entity Resolution, etc.).\n", + "\n", + "We're NOT doing Record Linkage of two datasets here. Check the other notebook [Record-Linkage-Example](./Record-Linkage-Example.ipynb) if you want to learn how to do it with Entity Embed." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now the training process! Thanks to pytorch-lightning, it's easy to train, validate, and test with the same datamodule.\n", + "\n", + "We must choose the K of the Approximate Nearest Neighbors, i.e., the top K neighbors our model will use to find duplicates in the embedding space. Below we're setting it on `ann_k` and initializing the `EntityEmbedLocal` model object:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "setting embedding size to 100\n" + ] + } + ], + "source": [ + "from entity_embed_local import EntityEmbedLocal\n", + "\n", + "ann_k = 15\n", + "model = EntityEmbedLocal(\n", + " record_numericalizer,\n", + " ann_k=ann_k,\n", + " embedding_size=100,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To train, Entity Embed uses [pytorch-lightning Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html) in its `EntityEmbed.fit` method.\n", + "\n", + "Since Entity Embed is focused on recall, we'll use `valid_recall_at_0.7` for early stopping. But we'll set `min_epochs = 2` to avoid a very low precision.\n", + "\n", + "`0.7` here is the threshold for **cosine similarity of embedding vectors**, so possible values are between -1 and 1. We're using a validation metric, and the training process will run validation on every epoch end due to `check_val_every_n_epoch=1`.\n", + "\n", + "We also set `tb_name` and `tb_save_dir` to use Tensorboard. Run `tensorboard --logdir notebooks/tb_logs` to check the train and valid metrics during and after training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "13:50:27 INFO:GPU available: False, used: False\n", + "13:50:27 INFO:TPU available: False, using: 0 TPU cores\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "calling up trainer....\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "13:51:20 INFO:\n", + " | Name | Type | Params\n", + "-------------------------------------------\n", + "0 | blocker_net | BlockerNet | 6.6 M \n", + "1 | loss_fn | SupConLoss | 0 \n", + "-------------------------------------------\n", + "2.3 M Trainable params\n", + "4.3 M Non-trainable params\n", + "6.6 M Total params\n", + "26.554 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4: 100%|██████████| 1019/1019 [04:07<00:00, 4.12it/s, loss=0.818, v_num=27]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "14:10:50 INFO:Loading the best validation model from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/models/epoch=1-step=1832-v7.ckpt...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "setting embedding size to 100\n" + ] + } + ], + "source": [ + "\n", + "trainer = model.fit(\n", + " datamodule,\n", + " min_epochs=2,\n", + " max_epochs=5,\n", + " check_val_every_n_epoch=1,\n", + " early_stop_monitor=\"valid_recall_at_0.7\",\n", + " model_save_dir='models',\n", + " use_gpu=False,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "# !mkdir -p models/entityembed\n", + "!cp {model.trainer.checkpoint_callback.best_model_path} models/entityembed/ee-model.ckpt" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "path = 
\"models/entityembed/\"\n", + "\n", + "with open(path+ 'ee-train-records.json', 'w') as f:\n", + " json.dump(datamodule.train_record_dict, f, indent=4)\n", + "\n", + "with open(path+ 'ee-valid-records.json', 'w') as f:\n", + " json.dump(datamodule.valid_record_dict, f, indent=4)\n", + "\n", + "with open(path+ 'ee-test-records.json', 'w') as f:\n", + " json.dump(datamodule.test_record_dict, f, indent=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "setting embedding size to 100\n" + ] + } + ], + "source": [ + "model = EntityEmbedLocal.load_from_checkpoint('models/epoch=1-step=1832-v7.ckpt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`EntityEmbed.fit` keeps only the weights of the best validation model. With them, we can check the best performance on validation set:" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'valid_f1_at_0.3': 0.0,\n", + " 'valid_f1_at_0.5': 0.0,\n", + " 'valid_f1_at_0.7': 0.0,\n", + " 'valid_pair_entity_ratio_at_0.3': 10.15106931594115,\n", + " 'valid_pair_entity_ratio_at_0.5': 4.946155012892461,\n", + " 'valid_pair_entity_ratio_at_0.7': 1.6895191870165327,\n", + " 'valid_precision_at_0.3': 0.0,\n", + " 'valid_precision_at_0.5': 0.0,\n", + " 'valid_precision_at_0.7': 0.0,\n", + " 'valid_recall_at_0.3': 0.0,\n", + " 'valid_recall_at_0.5': 0.0,\n", + " 'valid_recall_at_0.7': 0.0}" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.validate(datamodule) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And we can check which fields are most important for the final embedding:" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"{'merchant_name': 0.2551257014274597,\n", + " 'merchant_name_semantic': 0.22146005928516388,\n", + " 'avg_tx_amount': 0.03779561445116997,\n", + " 'avg_tx_amount_semantic': 0.24249973893165588,\n", + " 'plaid_merchant': 0.07616355270147324,\n", + " 'plaid_merchant_semantic': 0.060125213116407394,\n", + " 'plaid_category': 0.07688859850168228,\n", + " 'plaid_category_semantic': 0.02994149550795555}" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.get_pool_weights()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again with the best validation model, we can check the performance on the test set:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "16:57:02 INFO:Test positive pair count: 30657\n" + ] + }, + { + "data": { + "text/plain": [ + "{'test_f1_at_0.3': 0.0,\n", + " 'test_f1_at_0.5': 0.0,\n", + " 'test_f1_at_0.7': 0.0,\n", + " 'test_pair_entity_ratio_at_0.3': 17.187250554323725,\n", + " 'test_pair_entity_ratio_at_0.5': 5.513303769401331,\n", + " 'test_pair_entity_ratio_at_0.7': 2.3838137472283814,\n", + " 'test_precision_at_0.3': 0.0,\n", + " 'test_precision_at_0.5': 0.0,\n", + " 'test_precision_at_0.7': 0.0,\n", + " 'test_recall_at_0.3': 0.0,\n", + " 'test_recall_at_0.5': 0.0,\n", + " 'test_recall_at_0.7': 0.0}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.test(datamodule)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Entity Embed achieves Recall of ~0.99 with Pair-Entity ratio below 100 on a variety of datasets. **Entity Embed aims for high recall at the expense of precision. 
Therefore, this library is suited for the Blocking/Indexing stage of an Entity Resolution pipeline.** A scalable and noise-tolerant Blocking procedure is often the main bottleneck for performance and quality on Entity Resolution pipelines, so this library aims to solve that. Note the ANN search on embedded records returns several candidate pairs that must be filtered to find the best matching pairs, possibly with a pairwise classifier. See the [Record-Linkage-Example](./Record-Linkage-Example.ipynb) for an example of matching." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## t-SNE visualization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's visualize a small sample of the test embeddings and see if they look properly clustered. First, get the embedding vectors:" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "# batch embedding: 100%|██████████| 286/286 [00:57<00:00, 4.95it/s]\n" + ] + } + ], + "source": [ + "test_vector_dict = model.predict(\n", + " record_dict=test_record_dict,\n", + " batch_size=eval_batch_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, produce the visualization:" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [], + "source": [ + "vis_sample_size = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [], + "source": [ + "n=20\n", + "test_cluster_dict = utils.record_dict_to_cluster_dict(test_record_dict, cluster_field)\n", + "vis_cluster_dict = dict(sorted(test_cluster_dict.items(), key=lambda x: len(x[1]), reverse=True)[n:vis_sample_size+n])\n", + "\n", + "vis_x = np.stack([test_vector_dict[id_] for cluster in vis_cluster_dict.values() for id_ in cluster])\n", + "vis_y = np.array([cluster_id for cluster_id, cluster in 
vis_cluster_dict.items() for __ in cluster])" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.manifold import TSNE\n", + "\n", + "tnse = TSNE(metric='cosine', perplexity=20, square_distances=True, random_state=random_seed)\n", + "tsne_results = tnse.fit_transform(vis_x)" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import itertools\n", + "import random\n", + "\n", + "plt.figure(figsize=(16,10))\n", + "ax = sns.scatterplot(\n", + " x=tsne_results[:,0],\n", + " y=tsne_results[:,1],\n", + " hue=vis_y,\n", + " palette=sns.color_palette(\"hls\", len(vis_cluster_dict.keys())),\n", + " legend=\"full\",\n", + " alpha=0.8\n", + ")\n", + "for id_, (x, y) in zip(itertools.chain.from_iterable(vis_cluster_dict.values()), tsne_results):\n", + " # text = id_\n", + " text = test_record_dict[id_]['alias'][:25]\n", + " ax.text(x + 2 , y + 2+ (5*random.random()), text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing manually (like a production run)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When running in production, you only have access to the trained `model` object and the production `record_dict` (without the `cluster_field` filled, of course).\n", + "\n", + "So let's simulate that by removing `cluster_field` from the `test_record_dict`:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "\n", + "prod_test_record_dict = copy.deepcopy(test_record_dict)\n", + "\n", + "for record in prod_test_record_dict.values():\n", + " del record[cluster_field]" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "prod_test_record_dict = {}\n", + "\n", + "merchants_list = pd.read_csv('data/us_companies.csv', header=None,names=['merchant_name']).iloc[:,0].tolist()\n", + "merchants_list\n", + "\n", + "merchants = pd.read_csv('data/merchants_list.csv',names=[\"merchant_name\",\"merchant_id\",\"plaid_merchant\",\"plaid_category\",\"avg_tx_amount\",\"count\"])\n", + "merchants = 
merchants.sort_values('count', ascending=False).drop_duplicates(['merchant_name','merchant_id'])\n", + "merchants.avg_tx_amount = merchants.avg_tx_amount.round()\n", + "merchants = merchants.apply(lambda x: x.astype(str).str.lower()).reset_index()\n", + "for current_record_id, record in merchants.iterrows():\n", + " prod_test_record_dict[current_record_id] = record.to_dict()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then call `predict_pairs` with some `ann_k` and `sim_threshold`:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "# batch embedding: 69%|██████▉ | 410/592 [00:39<00:19, 9.19it/s]12:56:38 WARNING:Found out of alphabet char at val=döner, char=ö\n", + "# batch embedding: 84%|████████▍ | 497/592 [00:48<00:08, 11.22it/s]12:56:47 WARNING:Found out of alphabet char at val=côte, char=ô\n", + "# batch embedding: 100%|██████████| 592/592 [00:56<00:00, 10.52it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "72547" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sim_threshold = 0.85\n", + "ann_k=15\n", + "found_pair_set = model.predict_pairs(\n", + " record_dict=prod_test_record_dict,\n", + " batch_size=eval_batch_size,\n", + " ann_k=ann_k,\n", + " sim_threshold=sim_threshold\n", + ")\n", + "len(found_pair_set)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "37865" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_merchants = merchants.merchant_name.nunique()\n", + "total_merchants\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "idf_df = pd.read_csv('data/idf_weights.csv')\n", + "idf_weights = dict(zip(idf_df.token,idf_df.idf))\n" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 291, + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_genericity(tokens_left,tokens_right):\n", + " tokens = tokens_left + tokens_right\n", + " return (sum([1-idf_weights.get(t,1) for t in tokens])/len(tokens))\n", + "\n", + "def calculate_amount_similarity(l_amount, r_amount):\n", + " if max(l_amount,r_amount) == 0 or l_amount * r_amount < 0:\n", + " return False\n", + " l_amount, r_amount = (abs(n) for n in (l_amount, r_amount))\n", + " is_different_bracket = abs(l_amount - r_amount) > 10\n", + " is_ratio_large = (min(l_amount , r_amount)/ max(l_amount , r_amount)) < 0.8\n", + " return is_different_bracket or is_ratio_large\n", + "# calculate_genericity(prod_test_record_dict.get(3256)['merchant_name'].split(),prod_test_record_dict.get(12722)['merchant_name'].split())\n", + "# calculate_amount_similarity(82.0, -100.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 302, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "06ea0405b6e04d07801df03e080afd60", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/72547 [00:000.85 :\n", + " # print(prod_test_record_dict.get(id_left))\n", + " # print(prod_test_record_dict.get(id_right))\n", + " # print(score)\n", + " # print(\"------------------\")\n", + " graph[id_left,id_right] = 1\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 304, + "metadata": {}, + "outputs": [], + "source": [ + "graph = csr_matrix(graph)\n", + "# print(graph)" + ] + }, + { + "cell_type": "code", + "execution_count": 305, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(35960, array([ 0, 1, 2, ..., 35957, 35958, 35959], dtype=int32))" + ] + }, + "execution_count": 305, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True)\n", + 
"(n_components,labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 306, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ad663eec48b6484ca04d314230c8836a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/37865 [00:00 1:\n", + " dup_cluster.append([prod_test_record_dict.get(key) for key in dupes])\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "'\\n'.join(pd.Series(dup_cluster).iloc[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 337, + "metadata": {}, + "outputs": [], + "source": [ + "from pandas import option_context\n", + "\n", + "with option_context('display.max_colwidth', 400):\n", + " (pd.DataFrame(dup_cluster).apply(lambda x: '\\n'.join(x.dropna().astype(str)), axis=1)).to_csv('data/dupes.csv')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.6537699722699062" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from entity_embed.evaluation import pair_entity_ratio\n", + "\n", + "pair_entity_ratio(len(found_pair_set), len(prod_test_record_dict))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's check now the metrics of the found duplicate pairs:" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "cos_similarity = lambda a, b: np.dot(a, b)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "fields = ['merchant_name',\n", + " 'plaid_merchant',\n", + " 'plaid_category',\n", + " 'avg_tx_amount','count']" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": 
"9f613bd140204935badae4f6c7c1a079", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/72542 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
matching_similaritymerchantmatched_merchant
00.915286{'merchant_name': 'carroll county', 'plaid_mer...(carroll emc, -24.0, 26)
10.868136{'merchant_name': 'express fuels', 'plaid_merc...(express nails, -70.0, 14)
20.905586{'merchant_name': 'white', 'plaid_merchant': '...(white hart inn, -8.0, 1)
30.961777{'merchant_name': 'poke city', 'plaid_merchant...(poke poku, -24.0, 7)
40.977440{'merchant_name': 'new china fun', 'plaid_merc...(new china restaurant, -22.0, 323)
............
725370.865486{'merchant_name': 'los', 'plaid_merchant': 'na...(los cabos, -60.0, 35)
725380.926597{'merchant_name': 'mega bev', 'plaid_merchant'...(mega liquor, -22.0, 121)
725390.902664{'merchant_name': 'village inn pizza', 'plaid_...(village restaurant, -20.0, 20)
725400.882056{'merchant_name': 'el amigo', 'plaid_merchant'...(el sol, -20.0, 19)
725410.859562{'merchant_name': 'campus bo', 'plaid_merchant...(campus bookstor building, -236.0, 5)
\n", + "

72542 rows × 3 columns

\n", + "" + ], + "text/plain": [ + " matching_similarity merchant \\\n", + "0 0.915286 {'merchant_name': 'carroll county', 'plaid_mer... \n", + "1 0.868136 {'merchant_name': 'express fuels', 'plaid_merc... \n", + "2 0.905586 {'merchant_name': 'white', 'plaid_merchant': '... \n", + "3 0.961777 {'merchant_name': 'poke city', 'plaid_merchant... \n", + "4 0.977440 {'merchant_name': 'new china fun', 'plaid_merc... \n", + "... ... ... \n", + "72537 0.865486 {'merchant_name': 'los', 'plaid_merchant': 'na... \n", + "72538 0.926597 {'merchant_name': 'mega bev', 'plaid_merchant'... \n", + "72539 0.902664 {'merchant_name': 'village inn pizza', 'plaid_... \n", + "72540 0.882056 {'merchant_name': 'el amigo', 'plaid_merchant'... \n", + "72541 0.859562 {'merchant_name': 'campus bo', 'plaid_merchant... \n", + "\n", + " matched_merchant \n", + "0 (carroll emc, -24.0, 26) \n", + "1 (express nails, -70.0, 14) \n", + "2 (white hart inn, -8.0, 1) \n", + "3 (poke poku, -24.0, 7) \n", + "4 (new china restaurant, -22.0, 323) \n", + "... ... 
\n", + "72537 (los cabos, -60.0, 35) \n", + "72538 (mega liquor, -22.0, 121) \n", + "72539 (village restaurant, -20.0, 20) \n", + "72540 (el sol, -20.0, 19) \n", + "72541 (campus bookstor building, -236.0, 5) \n", + "\n", + "[72542 rows x 3 columns]" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from tqdm.notebook import tqdm\n", + "duplicates = pd.DataFrame(columns=['matching_similarity','merchant','matched_merchant'])\n", + "for (id_left, id_right, similarity) in tqdm(list(found_pair_set)):\n", + " merchant_a = utils.subdict(prod_test_record_dict[id_left], fields)\n", + " merchant_b = utils.subdict(prod_test_record_dict[id_right], fields)\n", + " duplicates = duplicates.append({'matching_similarity':similarity,'merchant':merchant_a,'matched_merchant':(merchant_b['merchant_name'],merchant_b['avg_tx_amount'],merchant_b['count'])},ignore_index=True)\n", + "\n", + "duplicates" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/adnanshahzada/opt/miniconda3/envs/entity-embed-env/lib/python3.8/site-packages/pandas/core/generic.py:5516: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " self[name] = value\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
matching_similaritymatched_merchant
merchant
{'merchant_name': \"america''s best\", 'plaid_merchant': \"america''s best\", 'plaid_category': 'food and drink restaurants ', 'avg_tx_amount': '-107.0', 'count': '1341'}[0.931][(america''s best wings, -22.0, 281)]
{'merchant_name': \"auntie anne''s\", 'plaid_merchant': \"auntie anne''s\", 'plaid_category': 'food and drink restaurants ', 'avg_tx_amount': '-10.0', 'count': '22971'}[0.981, 0.978][(auntie annes, -11.0, 17), (auntie annies, -8...
{'merchant_name': \"bahama buck''s\", 'plaid_merchant': \"bahama buck''s tx\", 'plaid_category': 'shops clothing and accessories ', 'avg_tx_amount': '-10.0', 'count': '39'}[0.952][(bahama bucks, -11.0, 437)]
{'merchant_name': \"baker''s dozen\", 'plaid_merchant': \"baker''s dozen\", 'plaid_category': 'shops food and beverage store ', 'avg_tx_amount': '-12.0', 'count': '4'}[0.901][(baker''s iga, -24.0, 21)]
{'merchant_name': \"bill''s liquor\", 'plaid_merchant': \"bill''s liquor\", 'plaid_category': 'shops food and beverage store ', 'avg_tx_amount': '-23.0', 'count': '19'}[0.921][(bill''s superette, -15.0, 232)]
.........
{'merchant_name': 'zoom management', 'plaid_merchant': 'nan', 'plaid_category': 'service ', 'avg_tx_amount': '-30.0', 'count': '22'}[0.903, 0.962, 0.935][(zoom tan, -23.0, 208), (zoom mart, -19.0, 14...
{'merchant_name': 'zoom mart', 'plaid_merchant': 'nan', 'plaid_category': 'service ', 'avg_tx_amount': '-19.0', 'count': '14'}[0.95, 0.926][(zoom.us, -16.0, 513), (zoom tan, -23.0, 208)]
{'merchant_name': 'zoom tan', 'plaid_merchant': 'zoom tan', 'plaid_category': 'service personal care ', 'avg_tx_amount': '-23.0', 'count': '208'}[0.927][(zoom.us, -16.0, 513)]
{'merchant_name': 'zoom', 'plaid_merchant': 'nan', 'plaid_category': 'service ', 'avg_tx_amount': '-19.0', 'count': '99'}[0.964, 0.985, 0.966, 0.933, 0.922, 0.969][(zoom.us, -16.0, 513), (zoom mart, -19.0, 14)...
{'merchant_name': 'zt', 'plaid_merchant': 'nan', 'plaid_category': 'transfer debit ', 'avg_tx_amount': '-11.0', 'count': '139'}[0.98][(zt 645, -12.0, 14)]
\n", + "

9918 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " matching_similarity \\\n", + "merchant \n", + "{'merchant_name': \"america''s best\", 'plaid_mer... [0.931] \n", + "{'merchant_name': \"auntie anne''s\", 'plaid_merc... [0.981, 0.978] \n", + "{'merchant_name': \"bahama buck''s\", 'plaid_merc... [0.952] \n", + "{'merchant_name': \"baker''s dozen\", 'plaid_merc... [0.901] \n", + "{'merchant_name': \"bill''s liquor\", 'plaid_merc... [0.921] \n", + "... ... \n", + "{'merchant_name': 'zoom management', 'plaid_mer... [0.903, 0.962, 0.935] \n", + "{'merchant_name': 'zoom mart', 'plaid_merchant'... [0.95, 0.926] \n", + "{'merchant_name': 'zoom tan', 'plaid_merchant':... [0.927] \n", + "{'merchant_name': 'zoom', 'plaid_merchant': 'na... [0.964, 0.985, 0.966, 0.933, 0.922, 0.969] \n", + "{'merchant_name': 'zt', 'plaid_merchant': 'nan'... [0.98] \n", + "\n", + " matched_merchant \n", + "merchant \n", + "{'merchant_name': \"america''s best\", 'plaid_mer... [(america''s best wings, -22.0, 281)] \n", + "{'merchant_name': \"auntie anne''s\", 'plaid_merc... [(auntie annes, -11.0, 17), (auntie annies, -8... \n", + "{'merchant_name': \"bahama buck''s\", 'plaid_merc... [(bahama bucks, -11.0, 437)] \n", + "{'merchant_name': \"baker''s dozen\", 'plaid_merc... [(baker''s iga, -24.0, 21)] \n", + "{'merchant_name': \"bill''s liquor\", 'plaid_merc... [(bill''s superette, -15.0, 232)] \n", + "... ... \n", + "{'merchant_name': 'zoom management', 'plaid_mer... [(zoom tan, -23.0, 208), (zoom mart, -19.0, 14... \n", + "{'merchant_name': 'zoom mart', 'plaid_merchant'... [(zoom.us, -16.0, 513), (zoom tan, -23.0, 208)] \n", + "{'merchant_name': 'zoom tan', 'plaid_merchant':... [(zoom.us, -16.0, 513)] \n", + "{'merchant_name': 'zoom', 'plaid_merchant': 'na... [(zoom.us, -16.0, 513), (zoom mart, -19.0, 14)... \n", + "{'merchant_name': 'zt', 'plaid_merchant': 'nan'... 
[(zt 645, -12.0, 14)] \n", + "\n", + "[9918 rows x 2 columns]" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "duplicates_strong = duplicates[duplicates.matching_similarity>0.9]\n", + "duplicates_strong.matching_similarity = duplicates_strong.matching_similarity.apply(lambda x: round(x,3))\n", + "duplicates_strong.merchant = duplicates_strong.merchant.astype(str)\n", + "result =duplicates_strong.groupby('merchant').agg(lambda x: list(x))\n", + "result.to_csv('data/duplicate_merchants_ct.csv')\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Baho Convenience Store', False),\n", + " ('Hilton Payroll', False),\n", + " ('Fast Shop', False),\n", + " ('Imperial Mart', False),\n", + " ('Bitterroot Beanery', False),\n", + " ('Piedmont Natural Gas', False),\n", + " ('The Island Shoppe', False),\n", + " ('Americas Best Wings', False),\n", + " ('Rockland Nails', False),\n", + " ('Irobot Corporation', False),\n", + " ('Khalil`s Food & Liquor', False),\n", + " ('Quality Auto Repair', False),\n", + " ('Jackpot Mini Mart', False),\n", + " ('Westside Convenience', False),\n", + " ('H And M Mini Sto', False),\n", + " ('Patel Corner Pantry', False),\n", + " ('Guys Pizza Downtown', False),\n", + " ('Xingyu Restaurant Inc', False),\n", + " ('5801 Video Lounge & Caf', False),\n", + " ('Just Salad', True)]" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idx_a += 20\n", + "idx_b += 20\n", + "[(r,r in merchants_list) for r in dataset.iloc[idx_a:idx_b,:]['merchant_name'].tolist()]" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.7512554540215691, 0.12059282165133735)" + ] + }, + "execution_count": 144, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"from entity_embed.evaluation import precision_and_recall\n", + "\n", + "precision_and_recall(found_pair_set, datamodule.test_pos_pair_set)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Same numbers of the `trainer.test`, so our manual testing is fine." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we can check the false positives and negatives to see if they're really difficult:" + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "12086" + ] + }, + "execution_count": 145, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "false_positives = list(found_pair_set - datamodule.test_pos_pair_set)\n", + "len(false_positives)" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "266186" + ] + }, + "execution_count": 146, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "false_negatives = list(datamodule.test_pos_pair_set - found_pair_set)\n", + "len(false_negatives)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "true_positives = list(found_pair_set - datamodule.test_pos_pair_set)\n", + "len(false_positives)\n", + "for (id_left, id_right) in list(found_pair_set)[:10]:\n", + " display(\n", + " (\n", + " record_dict[id_left],\n", + " record_dict[id_right],\n", + " cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),\n", + " utils.subdict(record_dict[id_left], field_list), utils.subdict(record_dict[id_right], field_list)\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.87486815,\n", + " {'alias': 'r&b tea', 'plaid_merchant': 'r&b tea', 'avg_tx_amount': '-6.35'},\n", + " {'alias': 'r & gs 
food basket',\n", + " 'plaid_merchant': 'r & gs food basket',\n", + " 'avg_tx_amount': '-21.83'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.8825392,\n", + " {'alias': 'country wide insurance',\n", + " 'plaid_merchant': 'country wide insurance',\n", + " 'avg_tx_amount': '-284.13'},\n", + " {'alias': 'country convenie',\n", + " 'plaid_merchant': 'country convenie',\n", + " 'avg_tx_amount': '-14.1'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.905407,\n", + " {'alias': 'china town 1', 'plaid_merchant': 'nan', 'avg_tx_amount': '-25.0'},\n", + " {'alias': 'china delight chinese',\n", + " 'plaid_merchant': 'china delight chinese',\n", + " 'avg_tx_amount': '-21.56'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.70510805,\n", + " {'alias': 'blue bay',\n", + " 'plaid_merchant': 'blue bay',\n", + " 'avg_tx_amount': '-12.96'},\n", + " {'alias': 'blue moon tap house',\n", + " 'plaid_merchant': 'blue moon tap house',\n", + " 'avg_tx_amount': '-26.07'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9555845,\n", + " {'alias': 'lincoln highway',\n", + " 'plaid_merchant': 'lincoln highway',\n", + " 'avg_tx_amount': '-11.24'},\n", + " {'alias': 'lincoln c mart',\n", + " 'plaid_merchant': 'lincoln c mart',\n", + " 'avg_tx_amount': '-24.99'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.95284754,\n", + " {'alias': 'quick pick grocery ltd',\n", + " 'plaid_merchant': 'nan',\n", + " 'avg_tx_amount': '-8.98'},\n", + " {'alias': 'quick pick atlanta',\n", + " 'plaid_merchant': 'quick pick',\n", + " 'avg_tx_amount': '-5.26'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.8914383,\n", + " {'alias': 'china gourmet house',\n", + " 
'plaid_merchant': 'china gourmet house',\n", + " 'avg_tx_amount': '-23.85'},\n", + " {'alias': 'china town 1', 'plaid_merchant': 'nan', 'avg_tx_amount': '-36.75'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9162289,\n", + " {'alias': 'taco heads fort worth',\n", + " 'plaid_merchant': 'taco heads fort worth',\n", + " 'avg_tx_amount': '-18.76'},\n", + " {'alias': 'taco monster',\n", + " 'plaid_merchant': 'taco monster',\n", + " 'avg_tx_amount': '-29.71'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.95120347,\n", + " {'alias': 'noodles pho u',\n", + " 'plaid_merchant': 'noodles pho u',\n", + " 'avg_tx_amount': '-33.58'},\n", + " {'alias': 'noodles & company',\n", + " 'plaid_merchant': 'noodles & company',\n", + " 'avg_tx_amount': '-30.15'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9241347,\n", + " {'alias': 'noodles and dumplings',\n", + " 'plaid_merchant': 'noodles and dumplings',\n", + " 'avg_tx_amount': '-48.14'},\n", + " {'alias': 'noodles & company',\n", + " 'plaid_merchant': 'noodles & company',\n", + " 'avg_tx_amount': '-8.7'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for (id_left, id_right) in false_positives[:10]:\n", + " display(\n", + " (\n", + " cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),\n", + " utils.subdict(record_dict[id_left], field_list), utils.subdict(record_dict[id_right], field_list)\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 149, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.99850756,\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-20.42'},\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-18.5'})" + ] + }, + 
"metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.99695677,\n", + " {'alias': 'rent-a-center',\n", + " 'plaid_merchant': 'rent-a-center',\n", + " 'avg_tx_amount': '-121.85'},\n", + " {'alias': 'rent-a-center',\n", + " 'plaid_merchant': 'rent-a-center',\n", + " 'avg_tx_amount': '-38.35'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9966733,\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-17.18'},\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-26.18'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.97021323,\n", + " {'alias': 'cardtronics', 'plaid_merchant': 'nan', 'avg_tx_amount': '-43.5'},\n", + " {'alias': 'cardtronics',\n", + " 'plaid_merchant': 'cardtronics',\n", + " 'avg_tx_amount': '-41.16'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9971533,\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-23.7'},\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-7.62'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.99700713,\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-18.72'},\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-20.54'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9975728,\n", + " {'alias': \"dave & buster''s\",\n", + " 'plaid_merchant': 'dave & busters',\n", + " 'avg_tx_amount': '-35.75'},\n", + " {'alias': \"dave & buster''s\",\n", + " 'plaid_merchant': 'dave & busters',\n", + " 'avg_tx_amount': '-77.44'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + 
"data": { + "text/plain": [ + "(0.9961077,\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-14.02'},\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-9.2'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.99612856,\n", + " {'alias': \"dave & buster''s\",\n", + " 'plaid_merchant': 'dave & busters',\n", + " 'avg_tx_amount': '-46.28'},\n", + " {'alias': \"dave & buster''s\",\n", + " 'plaid_merchant': \"dave & buster''s\",\n", + " 'avg_tx_amount': '-33.63'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9968643,\n", + " {'alias': 'tractor supply',\n", + " 'plaid_merchant': 'tractor supply',\n", + " 'avg_tx_amount': '-28.6'},\n", + " {'alias': 'tractor supply',\n", + " 'plaid_merchant': 'tractor supply',\n", + " 'avg_tx_amount': '-54.1'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9972248,\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-32.71'},\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-10.57'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9969662,\n", + " {'alias': 'cardtronics',\n", + " 'plaid_merchant': 'cardtronics',\n", + " 'avg_tx_amount': '-184.65'},\n", + " {'alias': 'cardtronics',\n", + " 'plaid_merchant': 'cardtronics',\n", + " 'avg_tx_amount': '-52.5'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9916726,\n", + " {'alias': 'urban air killeen',\n", + " 'plaid_merchant': 'urban air killeen',\n", + " 'avg_tx_amount': '-36.15'},\n", + " {'alias': 'urban air',\n", + " 'plaid_merchant': 'urban air',\n", + " 'avg_tx_amount': '-19.75'})" + ] + }, + "metadata": {}, + 
"output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9962817,\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-98.06'},\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-10.69'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.997394,\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-5.0'},\n", + " {'alias': 'stop & shop',\n", + " 'plaid_merchant': 'stop & shop',\n", + " 'avg_tx_amount': '-7.48'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.99658275,\n", + " {'alias': 'rent-a-center',\n", + " 'plaid_merchant': 'rent-a-center',\n", + " 'avg_tx_amount': '-152.17'},\n", + " {'alias': 'rent-a-center',\n", + " 'plaid_merchant': 'rent-a-center',\n", + " 'avg_tx_amount': '-344.69'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.99650383,\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-24.02'},\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-6.12'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.997722,\n", + " {'alias': 'krispy kreme',\n", + " 'plaid_merchant': 'krispy kreme',\n", + " 'avg_tx_amount': '-12.15'},\n", + " {'alias': 'krispy kreme',\n", + " 'plaid_merchant': 'krispy kreme',\n", + " 'avg_tx_amount': '-11.94'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.99689335,\n", + " {'alias': 'rent-a-center',\n", + " 'plaid_merchant': 'rent-a-center',\n", + " 'avg_tx_amount': '-8.94'},\n", + " {'alias': 'rent-a-center',\n", + " 'plaid_merchant': 'rent-a-center',\n", + " 'avg_tx_amount': '-58.77'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + 
{ + "data": { + "text/plain": [ + "(0.9976802,\n", + " {'alias': 'cardtronics',\n", + " 'plaid_merchant': 'cardtronics',\n", + " 'avg_tx_amount': '-40.0'},\n", + " {'alias': 'cardtronics',\n", + " 'plaid_merchant': 'cardtronics',\n", + " 'avg_tx_amount': '-52.72'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.99777186,\n", + " {'alias': 'chipotle mexican grill',\n", + " 'plaid_merchant': 'chipotle mexican grill',\n", + " 'avg_tx_amount': '-14.17'},\n", + " {'alias': 'chipotle mexican grill',\n", + " 'plaid_merchant': 'chipotle mexican grill',\n", + " 'avg_tx_amount': '-12.07'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9955483,\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-12.93'},\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-23.91'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9961946,\n", + " {'alias': \"dave & buster''s\",\n", + " 'plaid_merchant': \"dave & buster''s\",\n", + " 'avg_tx_amount': '-38.62'},\n", + " {'alias': \"dave & buster''s\",\n", + " 'plaid_merchant': \"dave & buster''s\",\n", + " 'avg_tx_amount': '-70.86'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9975935,\n", + " {'alias': 'holiday inn',\n", + " 'plaid_merchant': 'holiday inn',\n", + " 'avg_tx_amount': '-296.98'},\n", + " {'alias': 'holiday inn',\n", + " 'plaid_merchant': 'holiday inn',\n", + " 'avg_tx_amount': '-39.1'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9973572,\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-48.81'},\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-8.24'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + 
"(0.99457943,\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-12.56'},\n", + " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-36.24'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9984803,\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-25.21'},\n", + " {'alias': 'kwik trip',\n", + " 'plaid_merchant': 'kwik trip',\n", + " 'avg_tx_amount': '-28.34'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.97368807,\n", + " {'alias': 'cardtronics', 'plaid_merchant': 'nan', 'avg_tx_amount': '-51.75'},\n", + " {'alias': 'cardtronics',\n", + " 'plaid_merchant': 'cardtronics',\n", + " 'avg_tx_amount': '-89.69'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9977134,\n", + " {'alias': 'gulf oil',\n", + " 'plaid_merchant': 'gulf oil',\n", + " 'avg_tx_amount': '-19.08'},\n", + " {'alias': 'gulf oil',\n", + " 'plaid_merchant': 'gulf oil',\n", + " 'avg_tx_amount': '-34.91'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(0.9969918,\n", + " {'alias': 'marshalls',\n", + " 'plaid_merchant': 'marshalls',\n", + " 'avg_tx_amount': '-76.19'},\n", + " {'alias': 'marshalls',\n", + " 'plaid_merchant': 'marshalls',\n", + " 'avg_tx_amount': '-43.74'})" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for (id_left, id_right) in false_negatives[:30]:\n", + " display(\n", + " (\n", + " cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),\n", + " utils.subdict(record_dict[id_left], field_list), utils.subdict(record_dict[id_right], field_list)\n", + " )\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.13", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": 
{ + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "vscode": { + "interpreter": { + "hash": "e3bb7d66ba21cc372144d4e6f3a54e31b034566124f778bc0ae068d657400bc6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 7ee7f7cd973c8d22e6c692f0e8aeb659744106bb Mon Sep 17 00:00:00 2001 From: Benj Pettit Date: Tue, 25 Oct 2022 15:24:09 +0100 Subject: [PATCH 02/17] feat(data_utils): enable custom transaction embedding --- entity_embed/data_utils/field_config_parser.py | 8 ++++++-- entity_embed/data_utils/numericalizer.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index e548235..84a8ccf 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -2,7 +2,7 @@ import logging from importlib import import_module -from torchtext.vocab import Vocab +from torchtext.vocab import Vocab, Vectors from .numericalizer import ( AVAILABLE_VOCABS, @@ -93,7 +93,11 @@ def _parse_field_config(cls, field, field_config, record_list): "an field name." 
) vocab = Vocab(vocab_counter) - vocab.load_vectors(vocab_type) + if vocab_type in {'tx_embeddings_large.vec','tx_embeddings.vec'}: + vectors = Vectors(vocab_type, cache='.vector_cache') + vocab.load_vectors(vectors) + else: + vocab.load_vectors(vocab_type) # Compute max_str_len if necessary if field_type in (FieldType.STRING, FieldType.MULTITOKEN) and (max_str_len is None): diff --git a/entity_embed/data_utils/numericalizer.py b/entity_embed/data_utils/numericalizer.py index dbdbdc8..4bd90df 100644 --- a/entity_embed/data_utils/numericalizer.py +++ b/entity_embed/data_utils/numericalizer.py @@ -27,6 +27,7 @@ "glove.6B.100d", "glove.6B.200d", "glove.6B.300d", + "tx_embeddings_large.vec", ] From 8764297ffa6242e4d294f422d76ebc81bf7ee4ba Mon Sep 17 00:00:00 2001 From: Benj Pettit Date: Tue, 25 Oct 2022 15:26:57 +0100 Subject: [PATCH 03/17] Revert "notebooks: added merchant deduplication notebook" This reverts commit a2cf02b72bb4c382902a861d4366c5f075128f9c. --- notebooks/Deduplication-merchants.ipynb | 2320 ----------------------- 1 file changed, 2320 deletions(-) delete mode 100755 notebooks/Deduplication-merchants.ipynb diff --git a/notebooks/Deduplication-merchants.ipynb b/notebooks/Deduplication-merchants.ipynb deleted file mode 100755 index 996a043..0000000 --- a/notebooks/Deduplication-merchants.ipynb +++ /dev/null @@ -1,2320 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Deduplication Example" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Boilerplate" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from importlib import reload\n", - "import logging\n", - "import torch\n", - "import json\n", - "reload(logging)\n", - "logging.basicConfig(format='%(asctime)s 
%(levelname)s:%(message)s', level=logging.INFO, datefmt='%H:%M:%S')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.insert(0, '..')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import entity_embed_local\n", - "from entity_embed_local import EntityEmbedLocal\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "torch.set_num_threads(1)\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "random_seed = 40\n", - "torch.manual_seed(random_seed)\n", - "np.random.seed(random_seed)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load Dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's download the CSV dataset to a temporary directory:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data = pd.read_csv('data/alias_data.csv',names=['merchant_name','merchant_id','plaid_merchant','plaid_category','avg_tx_amount','count'])\n", - "dataset = data.groupby('merchant_id').head(200)\n", - "dataset.avg_tx_amount = dataset.avg_tx_amount.round()\n", - "dataset = dataset.astype(str)\n", - "# dataset = dataset[dataset.merchant_id.isin(dataset.merchant_id.unique()[:50000])]\n", - "dataset['cluster_id'] = pd.factorize(dataset['merchant_id'].tolist())[0]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "mapping = pd.Series(dataset.merchant_name.values,index=dataset.cluster_id).to_dict()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "record_dict = {}\n", - "cluster_field = 'cluster_id'\n", - "\n", - "for current_record_id, record in 
dataset[['cluster_id','merchant_name','plaid_merchant','plaid_category','avg_tx_amount']].iterrows():\n", - " record['id'] = current_record_id\n", - " record[cluster_field] = int(record[cluster_field]) # convert cluster_field to int\n", - " record_dict[current_record_id] = record.to_dict()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "How many clusters this dataset has?" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "35104" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cluster_total = len(set(record[cluster_field] for record in record_dict.values()))\n", - "cluster_total" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "From all clusters, we'll use only 50% for training, and other 15% for validation to test how well we can generalize:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "12:55:26 INFO:Singleton cluster sizes (train, valid, test):(18649, 4662, 7772)\n", - "12:55:26 INFO:Plural cluster sizes (train, valid, test):(2412, 603, 1006)\n" - ] - } - ], - "source": [ - "from data_utils import utils\n", - "\n", - "train_record_dict, valid_record_dict, test_record_dict = utils.split_record_dict_on_clusters(\n", - " record_dict=record_dict,\n", - " cluster_field=cluster_field,\n", - " train_proportion=0.60,\n", - " valid_proportion=0.15,\n", - " random_seed=random_seed)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note we're splitting the data on **clusters**, not records, so the record counts vary:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(26719, 6591, 11098)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": 
"execute_result" - } - ], - "source": [ - "len(train_record_dict), len(valid_record_dict), len(test_record_dict)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocess" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We'll perform a very minimal preprocessing of the dataset. We want to simply force ASCII chars, lowercase all chars, and strip leading and trailing whitespace.\n", - "\n", - "The fields we'll clean are the ones we'll use:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "field_list = ['merchant_name','plaid_merchant','plaid_category','avg_tx_amount']" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import unidecode\n", - "\n", - "def clean_str(s):\n", - " return unidecode.unidecode(s).lower().strip()\n", - "\n", - "for record in record_dict.values():\n", - " for field in field_list:\n", - " record[field] = clean_str(record[field])" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'merchant_name': 'dutchbrosll',\n", - " 'plaid_merchant': 'dutch bros. coffee',\n", - " 'plaid_category': 'food and drink restaurants coffee shop',\n", - " 'avg_tx_amount': '-10.0'}" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "utils.subdict(record_dict[2], field_list)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Forcing ASCII chars in this dataset is useful to improve recall because there's little difference between accented and not-accented chars here. Also, this dataset contains mostly latin chars." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Configure Entity Embed fields" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we will define how record fields will be numericalized and encoded by the neural network. First we set an `alphabet`, here we'll use ASCII numbers, letters, symbols and space:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'0123456789abcdefghijklmnopqrstuvwxyz!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~ '" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from data_utils.field_config_parser import DEFAULT_ALPHABET\n", - "\n", - "alphabet = DEFAULT_ALPHABET\n", - "''.join(alphabet)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It's worth noting you can use any alphabet you need, so the accent removal we performed is optional." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then we set an `field_config_dict`. 
It defines `field_type`s that determine how fields are processed in the neural network:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "field_config_dict = {\n", - " \n", - " 'merchant_name': {\n", - " 'field_type': \"MULTITOKEN\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 'alphabet': alphabet,\n", - " 'max_str_len': None, # compute\n", - " },\n", - "\n", - " 'merchant_name_semantic': {\n", - " 'key': 'merchant_name',\n", - " 'field_type': \"SEMANTIC_MULTITOKEN\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 'vocab': \"tx_embeddings_large.vec\",\n", - " 'max_str_len': None, # compute\n", - " },\n", - " 'avg_tx_amount': {\n", - " 'field_type': \"STRING\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 'alphabet': alphabet,\n", - " 'max_str_len': None, # compute\n", - " },\n", - " 'avg_tx_amount_semantic': {\n", - " 'key':\"avg_tx_amount\",\n", - " 'field_type': \"SEMANTIC_STRING\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 'vocab': \"tx_embeddings_large.vec\",\n", - " 'max_str_len': None, # compute\n", - " },\n", - " 'plaid_merchant': {\n", - " 'field_type': \"MULTITOKEN\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 'alphabet': alphabet,\n", - " 'max_str_len': None, # compute\n", - " },\n", - " 'plaid_merchant_semantic': {\n", - " 'key': 'plaid_merchant',\n", - " 'field_type': \"SEMANTIC_MULTITOKEN\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 'vocab': \"tx_embeddings_large.vec\",\n", - " },\n", - " 'plaid_category': {\n", - " 'field_type': \"MULTITOKEN\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 'alphabet': alphabet,\n", - " 'max_str_len': None, # compute\n", - " },\n", - " 'plaid_category_semantic': {\n", - " 'key': 'plaid_category',\n", - " 'field_type': \"SEMANTIC_MULTITOKEN\",\n", - " 'tokenizer': \"entity_embed.default_tokenizer\",\n", - " 
'vocab': \"tx_embeddings_large.vec\",\n", - " },\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then we use our `field_config_dict` to get a `record_numericalizer`. This object will convert the strings from our records into tensors for the neural network.\n", - "\n", - "The same `record_numericalizer` must be used on ALL data: train, valid, test. This ensures numericalization will be consistent. Therefore, we pass `record_list=record_dict.values()`:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "12:55:27 INFO:For field=merchant_name, computing actual max_str_len\n", - "12:55:27 INFO:actual_max_str_len=23 must be even to enable NN pooling. Updating to 24\n", - "12:55:27 INFO:For field=merchant_name, using actual_max_str_len=24\n", - "12:55:27 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n", - "12:55:29 INFO:For field=avg_tx_amount, computing actual max_str_len\n", - "12:55:29 INFO:For field=avg_tx_amount, using actual_max_str_len=8\n", - "12:55:29 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n", - "12:55:30 INFO:For field=plaid_merchant, computing actual max_str_len\n", - "12:55:30 INFO:actual_max_str_len=23 must be even to enable NN pooling. Updating to 24\n", - "12:55:30 INFO:For field=plaid_merchant, using actual_max_str_len=24\n", - "12:55:31 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n", - "12:55:32 INFO:For field=plaid_category, computing actual max_str_len\n", - "12:55:32 INFO:actual_max_str_len=17 must be even to enable NN pooling. 
Updating to 18\n", - "12:55:32 INFO:For field=plaid_category, using actual_max_str_len=18\n", - "12:55:32 INFO:Loading vectors from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/.vector_cache/tx_embeddings_large.vec.pt\n" - ] - } - ], - "source": [ - "from data_utils.field_config_parser import FieldConfigDictParser\n", - "\n", - "record_numericalizer = FieldConfigDictParser.from_dict(field_config_dict, record_list=record_dict.values())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Initialize Data Module" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "under the hood, Entity Embed uses [pytorch-lightning](https://pytorch-lightning.readthedocs.io/en/latest/), so we need to create a datamodule object:" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from entity_embed import DeduplicationDataModule\n", - "\n", - "batch_size = 32\n", - "eval_batch_size = 64\n", - "datamodule = DeduplicationDataModule(\n", - " train_record_dict=train_record_dict,\n", - " valid_record_dict=valid_record_dict,\n", - " test_record_dict=test_record_dict,\n", - " cluster_field=cluster_field,\n", - " record_numericalizer=record_numericalizer,\n", - " batch_size=batch_size,\n", - " eval_batch_size=eval_batch_size,\n", - " random_seed=random_seed,\n", - " train_loader_kwargs ={\"num_workers\":0,\"multiprocessing_context\":\"fork\"},\n", - " eval_loader_kwargs ={\"num_workers\":0,\"multiprocessing_context\":\"fork\"},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We've used `DeduplicationDataModule` because we're doing Deduplication of a single dataset/table (a.k.a. Entity Clustering, Entity Resolution, etc.).\n", - "\n", - "We're NOT doing Record Linkage of two datasets here. Check the other notebook [Record-Linkage-Example](./Record-Linkage-Example.ipynb) if you want to learn how to do it with Entity Embed." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now the training process! Thanks to pytorch-lightning, it's easy to train, validate, and test with the same datamodule.\n", - "\n", - "We must choose the K of the Approximate Nearest Neighbors, i.e., the top K neighbors our model will use to find duplicates in the embedding space. Below we're setting it on `ann_k` and initializing the `EntityEmbed` model object:" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "setting embedding size to 100\n" - ] - } - ], - "source": [ - "from entity_embed_local import EntityEmbedLocal\n", - "\n", - "ann_k = 15\n", - "model = EntityEmbedLocal(\n", - " record_numericalizer,\n", - " ann_k=ann_k,\n", - " embedding_size=100,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To train, Entity Embed uses [pytorch-lightning Trainer](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html) on it's `EntityEmbed.fit` method.\n", - "\n", - "Since Entity Embed is focused in recall, we'll use `valid_recall_at_0.3` for early stopping. But we'll set `min_epochs = 5` to avoid a very low precision.\n", - "\n", - "`0.3` here is the threshold for **cosine similarity of embedding vectors**, so possible values are between -1 and 1. We're using a validation metric, and the training process will run validation on every epoch end due to `check_val_every_n_epoch=1`.\n", - "\n", - "We also set `tb_name` and `tb_save_dir` to use Tensorboard. Run `tensorboard --logdir notebooks/tb_logs` to check the train and valid metrics during and after training." 
- ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "13:50:27 INFO:GPU available: False, used: False\n", - "13:50:27 INFO:TPU available: False, using: 0 TPU cores\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "calling up trainer....\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "13:51:20 INFO:\n", - " | Name | Type | Params\n", - "-------------------------------------------\n", - "0 | blocker_net | BlockerNet | 6.6 M \n", - "1 | loss_fn | SupConLoss | 0 \n", - "-------------------------------------------\n", - "2.3 M Trainable params\n", - "4.3 M Non-trainable params\n", - "6.6 M Total params\n", - "26.554 Total estimated model params size (MB)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 4: 100%|██████████| 1019/1019 [04:07<00:00, 4.12it/s, loss=0.818, v_num=27]" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "14:10:50 INFO:Loading the best validation model from /Users/adnanshahzada/Cleo/Repos/deduplication/entity-embed/entity_embed/models/epoch=1-step=1832-v7.ckpt...\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "setting embedding size to 100\n" - ] - } - ], - "source": [ - "\n", - "trainer = model.fit(\n", - " datamodule,\n", - " min_epochs=2,\n", - " max_epochs=5,\n", - " check_val_every_n_epoch=1,\n", - " early_stop_monitor=\"valid_recall_at_0.7\",\n", - " model_save_dir='models',\n", - " use_gpu=False,\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "# !mkdir -p models/entityembed\n", - "!cp {model.trainer.checkpoint_callback.best_model_path} models/entityembed/ee-model.ckpt" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [], - "source": [ - "path = 
\"models/entityembed/\"\n", - "\n", - "with open(path+ 'ee-train-records.json', 'w') as f:\n", - " json.dump(datamodule.train_record_dict, f, indent=4)\n", - "\n", - "with open(path+ 'ee-valid-records.json', 'w') as f:\n", - " json.dump(datamodule.valid_record_dict, f, indent=4)\n", - "\n", - "with open(path+ 'ee-test-records.json', 'w') as f:\n", - " json.dump(datamodule.test_record_dict, f, indent=4)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "setting embedding size to 100\n" - ] - } - ], - "source": [ - "model = EntityEmbedLocal.load_from_checkpoint('models/epoch=1-step=1832-v7.ckpt')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`EntityEmbed.fit` keeps only the weights of the best validation model. With them, we can check the best performance on validation set:" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'valid_f1_at_0.3': 0.0,\n", - " 'valid_f1_at_0.5': 0.0,\n", - " 'valid_f1_at_0.7': 0.0,\n", - " 'valid_pair_entity_ratio_at_0.3': 10.15106931594115,\n", - " 'valid_pair_entity_ratio_at_0.5': 4.946155012892461,\n", - " 'valid_pair_entity_ratio_at_0.7': 1.6895191870165327,\n", - " 'valid_precision_at_0.3': 0.0,\n", - " 'valid_precision_at_0.5': 0.0,\n", - " 'valid_precision_at_0.7': 0.0,\n", - " 'valid_recall_at_0.3': 0.0,\n", - " 'valid_recall_at_0.5': 0.0,\n", - " 'valid_recall_at_0.7': 0.0}" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.validate(datamodule) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And we can check which fields are most important for the final embedding:" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - 
"{'merchant_name': 0.2551257014274597,\n", - " 'merchant_name_semantic': 0.22146005928516388,\n", - " 'avg_tx_amount': 0.03779561445116997,\n", - " 'avg_tx_amount_semantic': 0.24249973893165588,\n", - " 'plaid_merchant': 0.07616355270147324,\n", - " 'plaid_merchant_semantic': 0.060125213116407394,\n", - " 'plaid_category': 0.07688859850168228,\n", - " 'plaid_category_semantic': 0.02994149550795555}" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.get_pool_weights()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Testing" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Again with the best validation model, we can check the performance on the test set:" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "16:57:02 INFO:Test positive pair count: 30657\n" - ] - }, - { - "data": { - "text/plain": [ - "{'test_f1_at_0.3': 0.0,\n", - " 'test_f1_at_0.5': 0.0,\n", - " 'test_f1_at_0.7': 0.0,\n", - " 'test_pair_entity_ratio_at_0.3': 17.187250554323725,\n", - " 'test_pair_entity_ratio_at_0.5': 5.513303769401331,\n", - " 'test_pair_entity_ratio_at_0.7': 2.3838137472283814,\n", - " 'test_precision_at_0.3': 0.0,\n", - " 'test_precision_at_0.5': 0.0,\n", - " 'test_precision_at_0.7': 0.0,\n", - " 'test_recall_at_0.3': 0.0,\n", - " 'test_recall_at_0.5': 0.0,\n", - " 'test_recall_at_0.7': 0.0}" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.test(datamodule)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Entity Embed achieves Recall of ~0.99 with Pair-Entity ratio below 100 on a variety of datasets. **Entity Embed aims for high recall at the expense of precision. 
Therefore, this library is suited for the Blocking/Indexing stage of an Entity Resolution pipeline.** A scalabale and noise-tolerant Blocking procedure is often the main bottleneck for performance and quality on Entity Resolution pipelines, so this library aims to solve that. Note the ANN search on embedded records returns several candidate pairs that must be filtered to find the best matching pairs, possibly with a pairwise classifier. See the [Record-Linkage-Example](./Record-Linkage-Example.ipynb) for an example of matching." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## t-sne visualization" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's visualize a small sample of the test embeddings and see if they look properly clustered. First, get the embedding vectors:" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "# batch embedding: 100%|██████████| 286/286 [00:57<00:00, 4.95it/s]\n" - ] - } - ], - "source": [ - "test_vector_dict = model.predict(\n", - " record_dict=test_record_dict,\n", - " batch_size=eval_batch_size\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then, produce the visualization:" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "metadata": {}, - "outputs": [], - "source": [ - "vis_sample_size = 10" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "metadata": {}, - "outputs": [], - "source": [ - "n=20\n", - "test_cluster_dict = utils.record_dict_to_cluster_dict(test_record_dict, cluster_field)\n", - "vis_cluster_dict = dict(sorted(test_cluster_dict.items(), key=lambda x: len(x[1]), reverse=True)[n:vis_sample_size+n])\n", - "\n", - "vis_x = np.stack([test_vector_dict[id_] for cluster in vis_cluster_dict.values() for id_ in cluster])\n", - "vis_y = np.array([cluster_id for cluster_id, cluster in 
vis_cluster_dict.items() for __ in cluster])" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "metadata": {}, - "outputs": [], - "source": [ - "from sklearn.manifold import TSNE\n", - "\n", - "tnse = TSNE(metric='cosine', perplexity=20, square_distances=True, random_state=random_seed)\n", - "tsne_results = tnse.fit_transform(vis_x)" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", - "import itertools\n", - "import random\n", - "\n", - "plt.figure(figsize=(16,10))\n", - "ax = sns.scatterplot(\n", - " x=tsne_results[:,0],\n", - " y=tsne_results[:,1],\n", - " hue=vis_y,\n", - " palette=sns.color_palette(\"hls\", len(vis_cluster_dict.keys())),\n", - " legend=\"full\",\n", - " alpha=0.8\n", - ")\n", - "for id_, (x, y) in zip(itertools.chain.from_iterable(vis_cluster_dict.values()), tsne_results):\n", - " # text = id_\n", - " text = test_record_dict[id_]['alias'][:25]\n", - " ax.text(x + 2 , y + 2+ (5*random.random()), text)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Testing manually (like a production run)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "When running in production, you only have access to the trained `model` object and the production `record_dict` (without the `cluster_field` filled, of course).\n", - "\n", - "So let's simulate that by removing `cluster_field` from the `test_record_dict`:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "import copy\n", - "\n", - "prod_test_record_dict = copy.deepcopy(test_record_dict)\n", - "\n", - "for record in prod_test_record_dict.values():\n", - " del record[cluster_field]" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "prod_test_record_dict = {}\n", - "\n", - "merchants_list = pd.read_csv('data/us_companies.csv', header=None,names=['merchant_name']).iloc[:,0].tolist()\n", - "merchants_list\n", - "\n", - "merchants = pd.read_csv('data/merchants_list.csv',names=[\"merchant_name\",\"merchant_id\",\"plaid_merchant\",\"plaid_category\",\"avg_tx_amount\",\"count\"])\n", - "merchants = 
merchants.sort_values('count', ascending=False).drop_duplicates(['merchant_name','merchant_id'])\n", - "merchants.avg_tx_amount = merchants.avg_tx_amount.round()\n", - "merchants = merchants.apply(lambda x: x.astype(str).str.lower()).reset_index()\n", - "for current_record_id, record in merchants.iterrows():\n", - " prod_test_record_dict[current_record_id] = record.to_dict()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Then call `predict_pairs` with some `ann_k` and `sim_threshold`:" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "# batch embedding: 69%|██████▉ | 410/592 [00:39<00:19, 9.19it/s]12:56:38 WARNING:Found out of alphabet char at val=döner, char=ö\n", - "# batch embedding: 84%|████████▍ | 497/592 [00:48<00:08, 11.22it/s]12:56:47 WARNING:Found out of alphabet char at val=côte, char=ô\n", - "# batch embedding: 100%|██████████| 592/592 [00:56<00:00, 10.52it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "72547" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sim_threshold = 0.85\n", - "ann_k=15\n", - "found_pair_set = model.predict_pairs(\n", - " record_dict=prod_test_record_dict,\n", - " batch_size=eval_batch_size,\n", - " ann_k=ann_k,\n", - " sim_threshold=sim_threshold\n", - ")\n", - "len(found_pair_set)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "37865" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "total_merchants = merchants.merchant_name.nunique()\n", - "total_merchants\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "idf_df = pd.read_csv('data/idf_weights.csv')\n", - "idf_weights = dict(zip(idf_df.token,idf_df.idf))\n" - ] 
- }, - { - "cell_type": "code", - "execution_count": 291, - "metadata": {}, - "outputs": [], - "source": [ - "def calculate_genericity(tokens_left,tokens_right):\n", - " tokens = tokens_left + tokens_right\n", - " return (sum([1-idf_weights.get(t,1) for t in tokens])/len(tokens))\n", - "\n", - "def calculate_amount_similarity(l_amount, r_amount):\n", - " if max(l_amount,r_amount) == 0 or l_amount * r_amount < 0:\n", - " return False\n", - " l_amount, r_amount = (abs(n) for n in (l_amount, r_amount))\n", - " is_different_bracket = abs(l_amount - r_amount) > 10\n", - " is_ratio_large = (min(l_amount , r_amount)/ max(l_amount , r_amount)) < 0.8\n", - " return is_different_bracket or is_ratio_large\n", - "# calculate_genericity(prod_test_record_dict.get(3256)['merchant_name'].split(),prod_test_record_dict.get(12722)['merchant_name'].split())\n", - "# calculate_amount_similarity(82.0, -100.0)" - ] - }, - { - "cell_type": "code", - "execution_count": 302, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "06ea0405b6e04d07801df03e080afd60", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/72547 [00:000.85 :\n", - " # print(prod_test_record_dict.get(id_left))\n", - " # print(prod_test_record_dict.get(id_right))\n", - " # print(score)\n", - " # print(\"------------------\")\n", - " graph[id_left,id_right] = 1\n", - " \n" - ] - }, - { - "cell_type": "code", - "execution_count": 304, - "metadata": {}, - "outputs": [], - "source": [ - "graph = csr_matrix(graph)\n", - "# print(graph)" - ] - }, - { - "cell_type": "code", - "execution_count": 305, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(35960, array([ 0, 1, 2, ..., 35957, 35958, 35959], dtype=int32))" - ] - }, - "execution_count": 305, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True)\n", - 
"(n_components,labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 306, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ad663eec48b6484ca04d314230c8836a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/37865 [00:00 1:\n", - " dup_cluster.append([prod_test_record_dict.get(key) for key in dupes])\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "'\\n'.join(pd.Series(dup_cluster).iloc[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 337, - "metadata": {}, - "outputs": [], - "source": [ - "from pandas import option_context\n", - "\n", - "with option_context('display.max_colwidth', 400):\n", - " (pd.DataFrame(dup_cluster).apply(lambda x: '\\n'.join(x.dropna().astype(str)), axis=1)).to_csv('data/dupes.csv')\n" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1.6537699722699062" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from entity_embed.evaluation import pair_entity_ratio\n", - "\n", - "pair_entity_ratio(len(found_pair_set), len(prod_test_record_dict))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's check now the metrics of the found duplicate pairs:" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "cos_similarity = lambda a, b: np.dot(a, b)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "fields = ['merchant_name',\n", - " 'plaid_merchant',\n", - " 'plaid_category',\n", - " 'avg_tx_amount','count']" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": 
"9f613bd140204935badae4f6c7c1a079", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/72542 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
matching_similaritymerchantmatched_merchant
00.915286{'merchant_name': 'carroll county', 'plaid_mer...(carroll emc, -24.0, 26)
10.868136{'merchant_name': 'express fuels', 'plaid_merc...(express nails, -70.0, 14)
20.905586{'merchant_name': 'white', 'plaid_merchant': '...(white hart inn, -8.0, 1)
30.961777{'merchant_name': 'poke city', 'plaid_merchant...(poke poku, -24.0, 7)
40.977440{'merchant_name': 'new china fun', 'plaid_merc...(new china restaurant, -22.0, 323)
............
725370.865486{'merchant_name': 'los', 'plaid_merchant': 'na...(los cabos, -60.0, 35)
725380.926597{'merchant_name': 'mega bev', 'plaid_merchant'...(mega liquor, -22.0, 121)
725390.902664{'merchant_name': 'village inn pizza', 'plaid_...(village restaurant, -20.0, 20)
725400.882056{'merchant_name': 'el amigo', 'plaid_merchant'...(el sol, -20.0, 19)
725410.859562{'merchant_name': 'campus bo', 'plaid_merchant...(campus bookstor building, -236.0, 5)
\n", - "

72542 rows × 3 columns

\n", - "" - ], - "text/plain": [ - " matching_similarity merchant \\\n", - "0 0.915286 {'merchant_name': 'carroll county', 'plaid_mer... \n", - "1 0.868136 {'merchant_name': 'express fuels', 'plaid_merc... \n", - "2 0.905586 {'merchant_name': 'white', 'plaid_merchant': '... \n", - "3 0.961777 {'merchant_name': 'poke city', 'plaid_merchant... \n", - "4 0.977440 {'merchant_name': 'new china fun', 'plaid_merc... \n", - "... ... ... \n", - "72537 0.865486 {'merchant_name': 'los', 'plaid_merchant': 'na... \n", - "72538 0.926597 {'merchant_name': 'mega bev', 'plaid_merchant'... \n", - "72539 0.902664 {'merchant_name': 'village inn pizza', 'plaid_... \n", - "72540 0.882056 {'merchant_name': 'el amigo', 'plaid_merchant'... \n", - "72541 0.859562 {'merchant_name': 'campus bo', 'plaid_merchant... \n", - "\n", - " matched_merchant \n", - "0 (carroll emc, -24.0, 26) \n", - "1 (express nails, -70.0, 14) \n", - "2 (white hart inn, -8.0, 1) \n", - "3 (poke poku, -24.0, 7) \n", - "4 (new china restaurant, -22.0, 323) \n", - "... ... 
\n", - "72537 (los cabos, -60.0, 35) \n", - "72538 (mega liquor, -22.0, 121) \n", - "72539 (village restaurant, -20.0, 20) \n", - "72540 (el sol, -20.0, 19) \n", - "72541 (campus bookstor building, -236.0, 5) \n", - "\n", - "[72542 rows x 3 columns]" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from tqdm.notebook import tqdm\n", - "duplicates = pd.DataFrame(columns=['matching_similarity','merchant','matched_merchant'])\n", - "for (id_left, id_right, similarity) in tqdm(list(found_pair_set)):\n", - " merchant_a = utils.subdict(prod_test_record_dict[id_left], fields)\n", - " merchant_b = utils.subdict(prod_test_record_dict[id_right], fields)\n", - " duplicates = duplicates.append({'matching_similarity':similarity,'merchant':merchant_a,'matched_merchant':(merchant_b['merchant_name'],merchant_b['avg_tx_amount'],merchant_b['count'])},ignore_index=True)\n", - "\n", - "duplicates" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/adnanshahzada/opt/miniconda3/envs/entity-embed-env/lib/python3.8/site-packages/pandas/core/generic.py:5516: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " self[name] = value\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
matching_similaritymatched_merchant
merchant
{'merchant_name': \"america''s best\", 'plaid_merchant': \"america''s best\", 'plaid_category': 'food and drink restaurants ', 'avg_tx_amount': '-107.0', 'count': '1341'}[0.931][(america''s best wings, -22.0, 281)]
{'merchant_name': \"auntie anne''s\", 'plaid_merchant': \"auntie anne''s\", 'plaid_category': 'food and drink restaurants ', 'avg_tx_amount': '-10.0', 'count': '22971'}[0.981, 0.978][(auntie annes, -11.0, 17), (auntie annies, -8...
{'merchant_name': \"bahama buck''s\", 'plaid_merchant': \"bahama buck''s tx\", 'plaid_category': 'shops clothing and accessories ', 'avg_tx_amount': '-10.0', 'count': '39'}[0.952][(bahama bucks, -11.0, 437)]
{'merchant_name': \"baker''s dozen\", 'plaid_merchant': \"baker''s dozen\", 'plaid_category': 'shops food and beverage store ', 'avg_tx_amount': '-12.0', 'count': '4'}[0.901][(baker''s iga, -24.0, 21)]
{'merchant_name': \"bill''s liquor\", 'plaid_merchant': \"bill''s liquor\", 'plaid_category': 'shops food and beverage store ', 'avg_tx_amount': '-23.0', 'count': '19'}[0.921][(bill''s superette, -15.0, 232)]
.........
{'merchant_name': 'zoom management', 'plaid_merchant': 'nan', 'plaid_category': 'service ', 'avg_tx_amount': '-30.0', 'count': '22'}[0.903, 0.962, 0.935][(zoom tan, -23.0, 208), (zoom mart, -19.0, 14...
{'merchant_name': 'zoom mart', 'plaid_merchant': 'nan', 'plaid_category': 'service ', 'avg_tx_amount': '-19.0', 'count': '14'}[0.95, 0.926][(zoom.us, -16.0, 513), (zoom tan, -23.0, 208)]
{'merchant_name': 'zoom tan', 'plaid_merchant': 'zoom tan', 'plaid_category': 'service personal care ', 'avg_tx_amount': '-23.0', 'count': '208'}[0.927][(zoom.us, -16.0, 513)]
{'merchant_name': 'zoom', 'plaid_merchant': 'nan', 'plaid_category': 'service ', 'avg_tx_amount': '-19.0', 'count': '99'}[0.964, 0.985, 0.966, 0.933, 0.922, 0.969][(zoom.us, -16.0, 513), (zoom mart, -19.0, 14)...
{'merchant_name': 'zt', 'plaid_merchant': 'nan', 'plaid_category': 'transfer debit ', 'avg_tx_amount': '-11.0', 'count': '139'}[0.98][(zt 645, -12.0, 14)]
\n", - "

9918 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " matching_similarity \\\n", - "merchant \n", - "{'merchant_name': \"america''s best\", 'plaid_mer... [0.931] \n", - "{'merchant_name': \"auntie anne''s\", 'plaid_merc... [0.981, 0.978] \n", - "{'merchant_name': \"bahama buck''s\", 'plaid_merc... [0.952] \n", - "{'merchant_name': \"baker''s dozen\", 'plaid_merc... [0.901] \n", - "{'merchant_name': \"bill''s liquor\", 'plaid_merc... [0.921] \n", - "... ... \n", - "{'merchant_name': 'zoom management', 'plaid_mer... [0.903, 0.962, 0.935] \n", - "{'merchant_name': 'zoom mart', 'plaid_merchant'... [0.95, 0.926] \n", - "{'merchant_name': 'zoom tan', 'plaid_merchant':... [0.927] \n", - "{'merchant_name': 'zoom', 'plaid_merchant': 'na... [0.964, 0.985, 0.966, 0.933, 0.922, 0.969] \n", - "{'merchant_name': 'zt', 'plaid_merchant': 'nan'... [0.98] \n", - "\n", - " matched_merchant \n", - "merchant \n", - "{'merchant_name': \"america''s best\", 'plaid_mer... [(america''s best wings, -22.0, 281)] \n", - "{'merchant_name': \"auntie anne''s\", 'plaid_merc... [(auntie annes, -11.0, 17), (auntie annies, -8... \n", - "{'merchant_name': \"bahama buck''s\", 'plaid_merc... [(bahama bucks, -11.0, 437)] \n", - "{'merchant_name': \"baker''s dozen\", 'plaid_merc... [(baker''s iga, -24.0, 21)] \n", - "{'merchant_name': \"bill''s liquor\", 'plaid_merc... [(bill''s superette, -15.0, 232)] \n", - "... ... \n", - "{'merchant_name': 'zoom management', 'plaid_mer... [(zoom tan, -23.0, 208), (zoom mart, -19.0, 14... \n", - "{'merchant_name': 'zoom mart', 'plaid_merchant'... [(zoom.us, -16.0, 513), (zoom tan, -23.0, 208)] \n", - "{'merchant_name': 'zoom tan', 'plaid_merchant':... [(zoom.us, -16.0, 513)] \n", - "{'merchant_name': 'zoom', 'plaid_merchant': 'na... [(zoom.us, -16.0, 513), (zoom mart, -19.0, 14)... \n", - "{'merchant_name': 'zt', 'plaid_merchant': 'nan'... 
[(zt 645, -12.0, 14)] \n", - "\n", - "[9918 rows x 2 columns]" - ] - }, - "execution_count": 45, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "duplicates_strong = duplicates[duplicates.matching_similarity>0.9]\n", - "duplicates_strong.matching_similarity = duplicates_strong.matching_similarity.apply(lambda x: round(x,3))\n", - "duplicates_strong.merchant = duplicates_strong.merchant.astype(str)\n", - "result =duplicates_strong.groupby('merchant').agg(lambda x: list(x))\n", - "result.to_csv('data/duplicate_merchants_ct.csv')\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[('Baho Convenience Store', False),\n", - " ('Hilton Payroll', False),\n", - " ('Fast Shop', False),\n", - " ('Imperial Mart', False),\n", - " ('Bitterroot Beanery', False),\n", - " ('Piedmont Natural Gas', False),\n", - " ('The Island Shoppe', False),\n", - " ('Americas Best Wings', False),\n", - " ('Rockland Nails', False),\n", - " ('Irobot Corporation', False),\n", - " ('Khalil`s Food & Liquor', False),\n", - " ('Quality Auto Repair', False),\n", - " ('Jackpot Mini Mart', False),\n", - " ('Westside Convenience', False),\n", - " ('H And M Mini Sto', False),\n", - " ('Patel Corner Pantry', False),\n", - " ('Guys Pizza Downtown', False),\n", - " ('Xingyu Restaurant Inc', False),\n", - " ('5801 Video Lounge & Caf', False),\n", - " ('Just Salad', True)]" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "idx_a += 20\n", - "idx_b += 20\n", - "[(r,r in merchants_list) for r in dataset.iloc[idx_a:idx_b,:]['merchant_name'].tolist()]" - ] - }, - { - "cell_type": "code", - "execution_count": 144, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(0.7512554540215691, 0.12059282165133735)" - ] - }, - "execution_count": 144, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - 
"from entity_embed.evaluation import precision_and_recall\n", - "\n", - "precision_and_recall(found_pair_set, datamodule.test_pos_pair_set)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Same numbers of the `trainer.test`, so our manual testing is fine." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we can check the false positives and negatives to see if they're really difficult:" - ] - }, - { - "cell_type": "code", - "execution_count": 145, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "12086" - ] - }, - "execution_count": 145, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "false_positives = list(found_pair_set - datamodule.test_pos_pair_set)\n", - "len(false_positives)" - ] - }, - { - "cell_type": "code", - "execution_count": 146, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "266186" - ] - }, - "execution_count": 146, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "false_negatives = list(datamodule.test_pos_pair_set - found_pair_set)\n", - "len(false_negatives)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "true_positives = list(found_pair_set - datamodule.test_pos_pair_set)\n", - "len(false_positives)\n", - "for (id_left, id_right) in list(found_pair_set)[:10]:\n", - " display(\n", - " (\n", - " record_dict[id_left],\n", - " record_dict[id_right],\n", - " cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),\n", - " utils.subdict(record_dict[id_left], field_list), utils.subdict(record_dict[id_right], field_list)\n", - " )\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 147, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(0.87486815,\n", - " {'alias': 'r&b tea', 'plaid_merchant': 'r&b tea', 'avg_tx_amount': '-6.35'},\n", - " {'alias': 'r & gs 
food basket',\n", - " 'plaid_merchant': 'r & gs food basket',\n", - " 'avg_tx_amount': '-21.83'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.8825392,\n", - " {'alias': 'country wide insurance',\n", - " 'plaid_merchant': 'country wide insurance',\n", - " 'avg_tx_amount': '-284.13'},\n", - " {'alias': 'country convenie',\n", - " 'plaid_merchant': 'country convenie',\n", - " 'avg_tx_amount': '-14.1'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.905407,\n", - " {'alias': 'china town 1', 'plaid_merchant': 'nan', 'avg_tx_amount': '-25.0'},\n", - " {'alias': 'china delight chinese',\n", - " 'plaid_merchant': 'china delight chinese',\n", - " 'avg_tx_amount': '-21.56'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.70510805,\n", - " {'alias': 'blue bay',\n", - " 'plaid_merchant': 'blue bay',\n", - " 'avg_tx_amount': '-12.96'},\n", - " {'alias': 'blue moon tap house',\n", - " 'plaid_merchant': 'blue moon tap house',\n", - " 'avg_tx_amount': '-26.07'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9555845,\n", - " {'alias': 'lincoln highway',\n", - " 'plaid_merchant': 'lincoln highway',\n", - " 'avg_tx_amount': '-11.24'},\n", - " {'alias': 'lincoln c mart',\n", - " 'plaid_merchant': 'lincoln c mart',\n", - " 'avg_tx_amount': '-24.99'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.95284754,\n", - " {'alias': 'quick pick grocery ltd',\n", - " 'plaid_merchant': 'nan',\n", - " 'avg_tx_amount': '-8.98'},\n", - " {'alias': 'quick pick atlanta',\n", - " 'plaid_merchant': 'quick pick',\n", - " 'avg_tx_amount': '-5.26'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.8914383,\n", - " {'alias': 'china gourmet house',\n", - " 
'plaid_merchant': 'china gourmet house',\n", - " 'avg_tx_amount': '-23.85'},\n", - " {'alias': 'china town 1', 'plaid_merchant': 'nan', 'avg_tx_amount': '-36.75'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9162289,\n", - " {'alias': 'taco heads fort worth',\n", - " 'plaid_merchant': 'taco heads fort worth',\n", - " 'avg_tx_amount': '-18.76'},\n", - " {'alias': 'taco monster',\n", - " 'plaid_merchant': 'taco monster',\n", - " 'avg_tx_amount': '-29.71'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.95120347,\n", - " {'alias': 'noodles pho u',\n", - " 'plaid_merchant': 'noodles pho u',\n", - " 'avg_tx_amount': '-33.58'},\n", - " {'alias': 'noodles & company',\n", - " 'plaid_merchant': 'noodles & company',\n", - " 'avg_tx_amount': '-30.15'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9241347,\n", - " {'alias': 'noodles and dumplings',\n", - " 'plaid_merchant': 'noodles and dumplings',\n", - " 'avg_tx_amount': '-48.14'},\n", - " {'alias': 'noodles & company',\n", - " 'plaid_merchant': 'noodles & company',\n", - " 'avg_tx_amount': '-8.7'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "for (id_left, id_right) in false_positives[:10]:\n", - " display(\n", - " (\n", - " cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),\n", - " utils.subdict(record_dict[id_left], field_list), utils.subdict(record_dict[id_right], field_list)\n", - " )\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 149, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(0.99850756,\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-20.42'},\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-18.5'})" - ] - }, - 
"metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.99695677,\n", - " {'alias': 'rent-a-center',\n", - " 'plaid_merchant': 'rent-a-center',\n", - " 'avg_tx_amount': '-121.85'},\n", - " {'alias': 'rent-a-center',\n", - " 'plaid_merchant': 'rent-a-center',\n", - " 'avg_tx_amount': '-38.35'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9966733,\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-17.18'},\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-26.18'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.97021323,\n", - " {'alias': 'cardtronics', 'plaid_merchant': 'nan', 'avg_tx_amount': '-43.5'},\n", - " {'alias': 'cardtronics',\n", - " 'plaid_merchant': 'cardtronics',\n", - " 'avg_tx_amount': '-41.16'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9971533,\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-23.7'},\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-7.62'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.99700713,\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-18.72'},\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-20.54'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9975728,\n", - " {'alias': \"dave & buster''s\",\n", - " 'plaid_merchant': 'dave & busters',\n", - " 'avg_tx_amount': '-35.75'},\n", - " {'alias': \"dave & buster''s\",\n", - " 'plaid_merchant': 'dave & busters',\n", - " 'avg_tx_amount': '-77.44'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - 
"data": { - "text/plain": [ - "(0.9961077,\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-14.02'},\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-9.2'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.99612856,\n", - " {'alias': \"dave & buster''s\",\n", - " 'plaid_merchant': 'dave & busters',\n", - " 'avg_tx_amount': '-46.28'},\n", - " {'alias': \"dave & buster''s\",\n", - " 'plaid_merchant': \"dave & buster''s\",\n", - " 'avg_tx_amount': '-33.63'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9968643,\n", - " {'alias': 'tractor supply',\n", - " 'plaid_merchant': 'tractor supply',\n", - " 'avg_tx_amount': '-28.6'},\n", - " {'alias': 'tractor supply',\n", - " 'plaid_merchant': 'tractor supply',\n", - " 'avg_tx_amount': '-54.1'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9972248,\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-32.71'},\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-10.57'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9969662,\n", - " {'alias': 'cardtronics',\n", - " 'plaid_merchant': 'cardtronics',\n", - " 'avg_tx_amount': '-184.65'},\n", - " {'alias': 'cardtronics',\n", - " 'plaid_merchant': 'cardtronics',\n", - " 'avg_tx_amount': '-52.5'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9916726,\n", - " {'alias': 'urban air killeen',\n", - " 'plaid_merchant': 'urban air killeen',\n", - " 'avg_tx_amount': '-36.15'},\n", - " {'alias': 'urban air',\n", - " 'plaid_merchant': 'urban air',\n", - " 'avg_tx_amount': '-19.75'})" - ] - }, - "metadata": {}, - 
"output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9962817,\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-98.06'},\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-10.69'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.997394,\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-5.0'},\n", - " {'alias': 'stop & shop',\n", - " 'plaid_merchant': 'stop & shop',\n", - " 'avg_tx_amount': '-7.48'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.99658275,\n", - " {'alias': 'rent-a-center',\n", - " 'plaid_merchant': 'rent-a-center',\n", - " 'avg_tx_amount': '-152.17'},\n", - " {'alias': 'rent-a-center',\n", - " 'plaid_merchant': 'rent-a-center',\n", - " 'avg_tx_amount': '-344.69'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.99650383,\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-24.02'},\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-6.12'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.997722,\n", - " {'alias': 'krispy kreme',\n", - " 'plaid_merchant': 'krispy kreme',\n", - " 'avg_tx_amount': '-12.15'},\n", - " {'alias': 'krispy kreme',\n", - " 'plaid_merchant': 'krispy kreme',\n", - " 'avg_tx_amount': '-11.94'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.99689335,\n", - " {'alias': 'rent-a-center',\n", - " 'plaid_merchant': 'rent-a-center',\n", - " 'avg_tx_amount': '-8.94'},\n", - " {'alias': 'rent-a-center',\n", - " 'plaid_merchant': 'rent-a-center',\n", - " 'avg_tx_amount': '-58.77'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - 
{ - "data": { - "text/plain": [ - "(0.9976802,\n", - " {'alias': 'cardtronics',\n", - " 'plaid_merchant': 'cardtronics',\n", - " 'avg_tx_amount': '-40.0'},\n", - " {'alias': 'cardtronics',\n", - " 'plaid_merchant': 'cardtronics',\n", - " 'avg_tx_amount': '-52.72'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.99777186,\n", - " {'alias': 'chipotle mexican grill',\n", - " 'plaid_merchant': 'chipotle mexican grill',\n", - " 'avg_tx_amount': '-14.17'},\n", - " {'alias': 'chipotle mexican grill',\n", - " 'plaid_merchant': 'chipotle mexican grill',\n", - " 'avg_tx_amount': '-12.07'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9955483,\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-12.93'},\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-23.91'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9961946,\n", - " {'alias': \"dave & buster''s\",\n", - " 'plaid_merchant': \"dave & buster''s\",\n", - " 'avg_tx_amount': '-38.62'},\n", - " {'alias': \"dave & buster''s\",\n", - " 'plaid_merchant': \"dave & buster''s\",\n", - " 'avg_tx_amount': '-70.86'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9975935,\n", - " {'alias': 'holiday inn',\n", - " 'plaid_merchant': 'holiday inn',\n", - " 'avg_tx_amount': '-296.98'},\n", - " {'alias': 'holiday inn',\n", - " 'plaid_merchant': 'holiday inn',\n", - " 'avg_tx_amount': '-39.1'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9973572,\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-48.81'},\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-8.24'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - 
"(0.99457943,\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-12.56'},\n", - " {'alias': 'wawa', 'plaid_merchant': 'wawa', 'avg_tx_amount': '-36.24'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9984803,\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-25.21'},\n", - " {'alias': 'kwik trip',\n", - " 'plaid_merchant': 'kwik trip',\n", - " 'avg_tx_amount': '-28.34'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.97368807,\n", - " {'alias': 'cardtronics', 'plaid_merchant': 'nan', 'avg_tx_amount': '-51.75'},\n", - " {'alias': 'cardtronics',\n", - " 'plaid_merchant': 'cardtronics',\n", - " 'avg_tx_amount': '-89.69'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9977134,\n", - " {'alias': 'gulf oil',\n", - " 'plaid_merchant': 'gulf oil',\n", - " 'avg_tx_amount': '-19.08'},\n", - " {'alias': 'gulf oil',\n", - " 'plaid_merchant': 'gulf oil',\n", - " 'avg_tx_amount': '-34.91'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(0.9969918,\n", - " {'alias': 'marshalls',\n", - " 'plaid_merchant': 'marshalls',\n", - " 'avg_tx_amount': '-76.19'},\n", - " {'alias': 'marshalls',\n", - " 'plaid_merchant': 'marshalls',\n", - " 'avg_tx_amount': '-43.74'})" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "for (id_left, id_right) in false_negatives[:30]:\n", - " display(\n", - " (\n", - " cos_similarity(test_vector_dict[id_left], test_vector_dict[id_right]),\n", - " utils.subdict(record_dict[id_left], field_list), utils.subdict(record_dict[id_right], field_list)\n", - " )\n", - " )" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.13", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": 
{ - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.13" - }, - "vscode": { - "interpreter": { - "hash": "e3bb7d66ba21cc372144d4e6f3a54e31b034566124f778bc0ae068d657400bc6" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 78ff0c218233c966fe7d8e0429a2ea2aa1b490d7 Mon Sep 17 00:00:00 2001 From: Hannah Date: Thu, 27 Oct 2022 10:47:03 +0100 Subject: [PATCH 04/17] Update to use faiss --- entity_embed/__init__.py | 2 +- .../data_utils/field_config_parser.py | 36 ++- entity_embed/data_utils/numericalizer.py | 5 +- entity_embed/early_stopping.py | 4 +- entity_embed/entity_embed.py | 17 +- entity_embed/indexes.py | 269 ++++++++++++------ entity_embed/models.py | 2 +- requirements.txt | 12 +- tests/test_data_utils_helpers.py | 3 +- 9 files changed, 243 insertions(+), 107 deletions(-) diff --git a/entity_embed/__init__.py b/entity_embed/__init__.py index 0ebd1b7..9ce1ae6 100644 --- a/entity_embed/__init__.py +++ b/entity_embed/__init__.py @@ -2,7 +2,7 @@ import logging # libgomp issue, must import n2 before torch. 
See: https://github.com/kakao/n2/issues/42 -import n2 # noqa: F401 +# import n2 # noqa: F401 from .data_modules import * # noqa: F401, F403 from .data_utils.field_config_parser import FieldConfigDictParser # noqa: F401 diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index e548235..ffecfa5 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -2,7 +2,9 @@ import logging from importlib import import_module -from torchtext.vocab import Vocab +from torch import Tensor, nn +from torchtext.vocab import Vocab, Vectors, build_vocab_from_iterator +from torchtext.vocab import vocab as tt_vocab from .numericalizer import ( AVAILABLE_VOCABS, @@ -66,6 +68,7 @@ def _parse_field_config(cls, field, field_config, record_list): alphabet = field_config.get("alphabet", DEFAULT_ALPHABET) max_str_len = field_config.get("max_str_len") vocab = None + vector_tensor = None # Check if there's a key defined on the field_config, # useful when we want to have multiple FieldConfig for the same field @@ -92,8 +95,33 @@ def _parse_field_config(cls, field, field_config, record_list): "field_config if you wish to use a override " "an field name." 
) - vocab = Vocab(vocab_counter) - vocab.load_vectors(vocab_type) + + vectors = Vectors(vocab_type, cache=".vector_cache") + + vocab = tt_vocab(vocab_counter) + + vectors = [vectors] + tot_dim = sum(v.dim for v in vectors) # 100 + vector_tensor = Tensor(len(vocab), tot_dim) + + for i, token in enumerate(vocab.get_itos()): + start_dim = 0 + for v in vectors: + end_dim = start_dim + v.dim + vector_tensor[i][start_dim:end_dim] = v[token.strip()] + start_dim = end_dim + assert start_dim == tot_dim + + print(f"Vector tensor shape: {vector_tensor.shape}") + + print(len(vector_tensor)) + print(len(vocab)) + + print(nn.Embedding.from_pretrained(vector_tensor)) + + # pretrained_vectors = vectors # torchtext.vocab.FastText("en") + + # vocab.load_vectors(vectors) # Compute max_str_len if necessary if field_type in (FieldType.STRING, FieldType.MULTITOKEN) and (max_str_len is None): @@ -128,6 +156,7 @@ def _parse_field_config(cls, field, field_config, record_list): alphabet=alphabet, max_str_len=max_str_len, vocab=vocab, + vector_tensor=vector_tensor, n_channels=n_channels, embed_dropout_p=embed_dropout_p, use_attention=use_attention, @@ -143,6 +172,7 @@ def _build_field_numericalizer(cls, field, field_config: FieldConfig): FieldType.SEMANTIC_STRING: SemanticStringNumericalizer, FieldType.SEMANTIC_MULTITOKEN: SemanticMultitokenNumericalizer, } + print(field_type_to_numericalizer_cls) numericalizer_cls = field_type_to_numericalizer_cls.get(field_type) if numericalizer_cls is None: raise ValueError(f"Unexpected field_type={field_type}") # pragma: no cover diff --git a/entity_embed/data_utils/numericalizer.py b/entity_embed/data_utils/numericalizer.py index dbdbdc8..2960e70 100644 --- a/entity_embed/data_utils/numericalizer.py +++ b/entity_embed/data_utils/numericalizer.py @@ -7,7 +7,7 @@ import numpy as np import regex import torch -from torchtext.vocab import Vocab +from torchtext.vocab import Vocab, Vectors logger = logging.getLogger(__name__) @@ -45,6 +45,7 @@ class 
FieldConfig: alphabet: List[str] max_str_len: int vocab: Vocab + vector_tensor: torch.Tensor n_channels: int embed_dropout_p: float use_attention: bool @@ -60,7 +61,7 @@ def __repr__(self): repr_dict = {} for k, v in self.__dict__.items(): if isinstance(v, Callable): - repr_dict[k] = f"{inspect.getmodule(v).__name__}.{v.__name__}" + repr_dict[k] = f"{inspect.getmodule(v).__name__}.{getattr(v, '.__name__', repr(v))}" else: repr_dict[k] = v return "{cls}({attrs})".format( diff --git a/entity_embed/early_stopping.py b/entity_embed/early_stopping.py index cf80e49..5e99b98 100644 --- a/entity_embed/early_stopping.py +++ b/entity_embed/early_stopping.py @@ -53,8 +53,8 @@ def __init__( save_top_k=save_top_k, save_weights_only=save_weights_only, mode=mode, - period=period, - prefix=prefix, + # period=period, + # prefix=prefix, ) self.min_epochs = min_epochs diff --git a/entity_embed/entity_embed.py b/entity_embed/entity_embed.py index c4b8d95..adcbf44 100644 --- a/entity_embed/entity_embed.py +++ b/entity_embed/entity_embed.py @@ -12,12 +12,17 @@ from .data_utils.datasets import RecordDataset from .early_stopping import EarlyStoppingMinEpochs, ModelCheckpointMinEpochs from .evaluation import f1_score, pair_entity_ratio, precision_and_recall -from .indexes import ANNEntityIndex, ANNLinkageIndex + +from .indexes import ANNEntityIndex # , ANNLinkageIndex from .models import BlockerNet logger = logging.getLogger(__name__) +def hannah_test(n): + return n * 2 + + class _BaseEmbed(pl.LightningModule): def __init__( self, @@ -42,11 +47,12 @@ def __init__( self.record_numericalizer = record_numericalizer for field_config in self.record_numericalizer.field_config_dict.values(): vocab = field_config.vocab + vector_tensor = field_config.vector_tensor if vocab: # We can assume that there's only one vocab type across the # whole field_config_dict, so we can stop the loop once we've # found a field_config with a vocab - valid_embedding_size = vocab.vectors.size(1) + valid_embedding_size 
= vector_tensor.size(1) if valid_embedding_size != embedding_size: raise ValueError( f"Invalid embedding_size={embedding_size}. " @@ -99,7 +105,7 @@ def training_step(self, batch, batch_idx): self.log("train_loss", loss) return loss - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=None): self.blocker_net.fix_pool_weights() self.log_dict( { @@ -189,7 +195,7 @@ def fit( "max_epochs": max_epochs, "check_val_every_n_epoch": check_val_every_n_epoch, "callbacks": [early_stop_callback, checkpoint_callback], - "reload_dataloaders_every_epoch": True, # for shuffling ClusterDataset every epoch + "reload_dataloaders_every_n_epochs": 10, # for shuffling ClusterDataset every epoch } if use_gpu: trainer_args["gpus"] = 1 @@ -205,8 +211,9 @@ def fit( "TensorBoardLogger or omit both to disable it" ) trainer = pl.Trainer(**trainer_args) + print("Trainer done") trainer.fit(self, datamodule) - + print("Model fit") logger.info( "Loading the best validation model from " f"{trainer.checkpoint_callback.best_model_path}..." 
diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index c59151a..7045372 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -1,6 +1,7 @@ import logging +import faiss -from n2 import HnswIndex +# from n2 import HnswIndex from .helpers import build_index_build_kwargs, build_index_search_kwargs @@ -9,13 +10,24 @@ class ANNEntityIndex: def __init__(self, embedding_size): - self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") + self.approx_knn_index = faiss.index_factory( + embedding_size, "Flat", faiss.METRIC_INNER_PRODUCT + ) + # self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") self.vector_idx_to_id = None self.is_built = False + print(self.approx_knn_index.is_trained) def insert_vector_dict(self, vector_dict): for vector in vector_dict.values(): - self.approx_knn_index.add_data(vector) + # self.approx_knn_index.add_data(vector) + + # print(vector.dtype) + # print(vector.shape) + # print(repr(vector)) + vector = vector.reshape(1, 100) + # vector = faiss.normalize_L2(vector) + self.approx_knn_index.add(vector) self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) def build( @@ -25,9 +37,11 @@ def build( if self.vector_idx_to_id is None: raise ValueError("Please call insert_vector_dict first") - actual_index_build_kwargs = build_index_build_kwargs(index_build_kwargs) - self.approx_knn_index.build(**actual_index_build_kwargs) + # actual_index_build_kwargs = build_index_build_kwargs(index_build_kwargs) + # self.approx_knn_index.build(**actual_index_build_kwargs) + print(self.approx_knn_index.ntotal) self.is_built = True + # faiss.write_index(self.approx_knn_index, "vector.index") def search_pairs(self, k, sim_threshold, index_search_kwargs=None): if not self.is_built: @@ -36,105 +50,188 @@ def search_pairs(self, k, sim_threshold, index_search_kwargs=None): raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") logger.debug("Searching on approx_knn_index...") - + 
print(sim_threshold) distance_threshold = 1 - sim_threshold + print(distance_threshold) index_search_kwargs = build_index_search_kwargs(index_search_kwargs) - neighbor_and_distance_list_of_list = self.approx_knn_index.batch_search_by_ids( - item_ids=self.vector_idx_to_id.keys(), - k=k, - include_distances=True, - **index_search_kwargs, - ) - - logger.debug("Search on approx_knn_index done, building found_pair_set now...") found_pair_set = set() - for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): + item_ids = self.vector_idx_to_id # .keys() + # print(item_ids) + for i in item_ids: + vector = self.approx_knn_index.reconstruct(i).reshape(1, 100) + # print(vector.shape) + # print(i) + similarities, neighbours = self.approx_knn_index.search(vector, k=k) left_id = self.vector_idx_to_id[i] - for j, distance in neighbor_distance_list: - if i != j and distance <= distance_threshold: + # print(similarities[0]) + for similarity, j in zip(similarities[0], neighbours[0]): + # print(j) + if i != j and similarity >= sim_threshold: right_id = self.vector_idx_to_id[j] # must use sorted to always have smaller id on left of pair tuple pair = tuple(sorted([left_id, right_id])) found_pair_set.add(pair) - logger.debug( - f"Building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." 
- ) - - return found_pair_set - - -class ANNLinkageIndex: - def __init__(self, embedding_size): - self.left_index = ANNEntityIndex(embedding_size) - self.right_index = ANNEntityIndex(embedding_size) - - def insert_vector_dict(self, left_vector_dict, right_vector_dict): - self.left_index.insert_vector_dict(vector_dict=left_vector_dict) - self.right_index.insert_vector_dict(vector_dict=right_vector_dict) - - def build( - self, - index_build_kwargs=None, - ): - self.left_index.build(index_build_kwargs=index_build_kwargs) - self.right_index.build(index_build_kwargs=index_build_kwargs) - - def search_pairs( - self, - k, - sim_threshold, - left_vector_dict, - right_vector_dict, - left_source, - index_search_kwargs=None, - ): - if not self.left_index.is_built or not self.right_index.is_built: - raise ValueError("Please call build first") - if sim_threshold > 1 or sim_threshold < 0: - raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") - - index_search_kwargs = build_index_search_kwargs(index_search_kwargs) - distance_threshold = 1 - sim_threshold - all_pair_set = set() - - for dataset_name, index, vector_dict, other_index in [ - (left_source, self.left_index, right_vector_dict, self.right_index), - (None, self.right_index, left_vector_dict, self.left_index), - ]: - logger.debug(f"Searching on approx_knn_index of dataset_name={dataset_name}...") + print(f"found_pair_set: {len(found_pair_set)}") + logger.debug("Search on approx_knn_index done, building found_pair_set now...") - neighbor_and_distance_list_of_list = index.approx_knn_index.batch_search_by_vectors( - vs=vector_dict.values(), k=k, include_distances=True, **index_search_kwargs + if False: + neighbor_and_distance_list_of_list = self.approx_knn_index.batch_search_by_ids( + item_ids=self.vector_idx_to_id.keys(), + k=k, + include_distances=True, + **index_search_kwargs, ) - logger.debug( - f"Search on approx_knn_index of dataset_name={dataset_name}... done, " - "filling all_pair_set now..." 
- ) + logger.debug("Search on approx_knn_index done, building found_pair_set now...") + found_pair_set = set() for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): - other_id = other_index.vector_idx_to_id[i] + left_id = self.vector_idx_to_id[i] for j, distance in neighbor_distance_list: - if distance <= distance_threshold: # do NOT check for i != j here - id_ = index.vector_idx_to_id[j] - if dataset_name and dataset_name == left_source: - left_id, right_id = (id_, other_id) - else: - left_id, right_id = (other_id, id_) - pair = ( - left_id, - right_id, - ) # do NOT use sorted here, figure out from datasets - all_pair_set.add(pair) - - logger.debug(f"Filling all_pair_set with dataset_name={dataset_name} done.") + if i != j and distance <= distance_threshold: + right_id = self.vector_idx_to_id[j] + # must use sorted to always have smaller id on left of pair tuple + pair = tuple(sorted([left_id, right_id])) + found_pair_set.add(pair) logger.debug( - "All searches done, all_pair_set filled. " - f"Found len(all_pair_set)={len(all_pair_set)} pairs." + f"Building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." 
) - return all_pair_set + return found_pair_set + + +# class ANNEntityIndex: +# def __init__(self, embedding_size): +# self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") +# self.vector_idx_to_id = None +# self.is_built = False + +# def insert_vector_dict(self, vector_dict): +# for vector in vector_dict.values(): +# self.approx_knn_index.add_data(vector) +# self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) + +# def build( +# self, +# index_build_kwargs=None, +# ): +# if self.vector_idx_to_id is None: +# raise ValueError("Please call insert_vector_dict first") + +# actual_index_build_kwargs = build_index_build_kwargs(index_build_kwargs) +# self.approx_knn_index.build(**actual_index_build_kwargs) +# self.is_built = True + +# def search_pairs(self, k, sim_threshold, index_search_kwargs=None): +# if not self.is_built: +# raise ValueError("Please call build first") +# if sim_threshold > 1 or sim_threshold < 0: +# raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") + +# logger.debug("Searching on approx_knn_index...") + +# distance_threshold = 1 - sim_threshold + +# index_search_kwargs = build_index_search_kwargs(index_search_kwargs) +# neighbor_and_distance_list_of_list = self.approx_knn_index.batch_search_by_ids( +# item_ids=self.vector_idx_to_id.keys(), +# k=k, +# include_distances=True, +# **index_search_kwargs, +# ) + +# logger.debug("Search on approx_knn_index done, building found_pair_set now...") + +# found_pair_set = set() +# for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): +# left_id = self.vector_idx_to_id[i] +# for j, distance in neighbor_distance_list: +# if i != j and distance <= distance_threshold: +# right_id = self.vector_idx_to_id[j] +# # must use sorted to always have smaller id on left of pair tuple +# pair = tuple(sorted([left_id, right_id])) +# found_pair_set.add(pair) + +# logger.debug( +# f"Building found_pair_set done. 
Found len(found_pair_set)={len(found_pair_set)} pairs." +# ) + +# return found_pair_set + + +# class ANNLinkageIndex: +# def __init__(self, embedding_size): +# self.left_index = ANNEntityIndex(embedding_size) +# self.right_index = ANNEntityIndex(embedding_size) + +# def insert_vector_dict(self, left_vector_dict, right_vector_dict): +# self.left_index.insert_vector_dict(vector_dict=left_vector_dict) +# self.right_index.insert_vector_dict(vector_dict=right_vector_dict) + +# def build( +# self, +# index_build_kwargs=None, +# ): +# self.left_index.build(index_build_kwargs=index_build_kwargs) +# self.right_index.build(index_build_kwargs=index_build_kwargs) + +# def search_pairs( +# self, +# k, +# sim_threshold, +# left_vector_dict, +# right_vector_dict, +# left_source, +# index_search_kwargs=None, +# ): +# if not self.left_index.is_built or not self.right_index.is_built: +# raise ValueError("Please call build first") +# if sim_threshold > 1 or sim_threshold < 0: +# raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") + +# index_search_kwargs = build_index_search_kwargs(index_search_kwargs) +# distance_threshold = 1 - sim_threshold +# all_pair_set = set() + +# for dataset_name, index, vector_dict, other_index in [ +# (left_source, self.left_index, right_vector_dict, self.right_index), +# (None, self.right_index, left_vector_dict, self.left_index), +# ]: +# logger.debug(f"Searching on approx_knn_index of dataset_name={dataset_name}...") + +# neighbor_and_distance_list_of_list = index.approx_knn_index.batch_search_by_vectors( +# vs=vector_dict.values(), k=k, include_distances=True, **index_search_kwargs +# ) + +# logger.debug( +# f"Search on approx_knn_index of dataset_name={dataset_name}... done, " +# "filling all_pair_set now..." 
+# ) + +# for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): +# other_id = other_index.vector_idx_to_id[i] +# for j, distance in neighbor_distance_list: +# if distance <= distance_threshold: # do NOT check for i != j here +# id_ = index.vector_idx_to_id[j] +# if dataset_name and dataset_name == left_source: +# left_id, right_id = (id_, other_id) +# else: +# left_id, right_id = (other_id, id_) +# pair = ( +# left_id, +# right_id, +# ) # do NOT use sorted here, figure out from datasets +# all_pair_set.add(pair) + +# logger.debug(f"Filling all_pair_set with dataset_name={dataset_name} done.") + +# logger.debug( +# "All searches done, all_pair_set filled. " +# f"Found len(all_pair_set)={len(all_pair_set)} pairs." +# ) + +# return all_pair_set diff --git a/entity_embed/models.py b/entity_embed/models.py index 68f8159..e78e735 100644 --- a/entity_embed/models.py +++ b/entity_embed/models.py @@ -58,7 +58,7 @@ def __init__(self, field_config, embedding_size): self.embedding_size = embedding_size self.dense_net = nn.Sequential( - nn.Embedding.from_pretrained(field_config.vocab.vectors), + nn.Embedding.from_pretrained(field_config.vector_tensor), nn.Dropout(p=field_config.embed_dropout_p), ) diff --git a/requirements.txt b/requirements.txt index 199a67b..e25e7b2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ click==7.1.2,<8.0 more-itertools>=8.6.0,<9.0 -n2>=0.1.7,<1.2 +# n2>=0.1.7,<1.2 numpy>=1.19.0 ordered-set>=4.0.2 -pytorch_lightning>=1.1.6,<1.3 -pytorch-metric-learning>=0.9.98,<1.0 +pytorch_lightning>=1.1.6 +pytorch-metric-learning>=0.9.98 regex>=2020.11.13 -torch>=1.7.1,<1.9 -torchtext>=0.8,<0.10 -torchvision>=0.8.2<0.10 +torch>=1.7.1 +torchtext>=0.8 +torchvision>=0.8.2 tqdm>=4.53.0 diff --git a/tests/test_data_utils_helpers.py b/tests/test_data_utils_helpers.py index 830717f..92ffe85 100644 --- a/tests/test_data_utils_helpers.py +++ b/tests/test_data_utils_helpers.py @@ -3,7 +3,8 @@ import tempfile import mock 
-import n2 # noqa: F401 + +# import n2 # noqa: F401 import pytest from entity_embed.data_utils.field_config_parser import FieldConfigDictParser from entity_embed.data_utils.numericalizer import FieldConfig, FieldType, RecordNumericalizer From 698d85ab013a54ab69b6fef08570c674ba087bb8 Mon Sep 17 00:00:00 2001 From: Hannah Date: Thu, 27 Oct 2022 11:08:48 +0100 Subject: [PATCH 05/17] Remove some print statements --- .../data_utils/field_config_parser.py | 28 ++++++++----------- entity_embed/data_utils/numericalizer.py | 2 +- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index 7e247f8..9aa7284 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -3,7 +3,8 @@ from importlib import import_module from torch import Tensor, nn -from torchtext.vocab import Vocab, Vectors +from torchtext.vocab import Vocab, Vectors, FastText +from torchtext.vocab import vocab as factory_vocab from .numericalizer import ( AVAILABLE_VOCABS, @@ -95,10 +96,16 @@ def _parse_field_config(cls, field, field_config, record_list): "an field name." 
) - vectors = Vectors(vocab_type, cache=".vector_cache") + if vocab_type in {"tx_embeddings_large.vec", "tx_embeddings.vec"}: + vectors = Vectors(vocab_type, cache=".vector_cache") + elif vocab_type == "fasttext": + vectors = FastText("en") # might need to add standard fasttext + else: + vocab.load_vectors(vocab_type) # won't work - vocab = torchtext.vocab.vocab(vocab_counter) + vocab = factory_vocab(vocab_counter) + # create vector tensor using tokens in vocab, order important vectors = [vectors] tot_dim = sum(v.dim for v in vectors) # 100 vector_tensor = Tensor(len(vocab), tot_dim) @@ -112,19 +119,9 @@ def _parse_field_config(cls, field, field_config, record_list): assert start_dim == tot_dim print(f"Vector tensor shape: {vector_tensor.shape}") + assert len(vector_tensor) == len(vocab) - print(len(vector_tensor)) - print(len(vocab)) - - print(nn.Embedding.from_pretrained(vector_tensor)) - - # pretrained_vectors = vectors # torchtext.vocab.FastText("en") - - #if vocab_type in {'tx_embeddings_large.vec', 'tx_embeddings.vec'}: - #vectors = Vectors(vocab_type, cache='.vector_cache') - #vocab.load_vectors(vectors) - #else: - #vocab.load_vectors(vocab_type) + print(nn.Embedding.from_pretrained(vector_tensor)) # check embedding works # Compute max_str_len if necessary if field_type in (FieldType.STRING, FieldType.MULTITOKEN) and (max_str_len is None): @@ -175,7 +172,6 @@ def _build_field_numericalizer(cls, field, field_config: FieldConfig): FieldType.SEMANTIC_STRING: SemanticStringNumericalizer, FieldType.SEMANTIC_MULTITOKEN: SemanticMultitokenNumericalizer, } - print(field_type_to_numericalizer_cls) numericalizer_cls = field_type_to_numericalizer_cls.get(field_type) if numericalizer_cls is None: raise ValueError(f"Unexpected field_type={field_type}") # pragma: no cover diff --git a/entity_embed/data_utils/numericalizer.py b/entity_embed/data_utils/numericalizer.py index 2bb6947..13f48f4 100644 --- a/entity_embed/data_utils/numericalizer.py +++ 
b/entity_embed/data_utils/numericalizer.py @@ -7,7 +7,7 @@ import numpy as np import regex import torch -from torchtext.vocab import Vocab, Vectors +from torchtext.vocab import Vocab logger = logging.getLogger(__name__) From d0020cea1dd0523c989a90d033dc7fea24639c5f Mon Sep 17 00:00:00 2001 From: Hannah Date: Thu, 27 Oct 2022 14:10:17 +0100 Subject: [PATCH 06/17] Add GPU param --- entity_embed/entity_embed.py | 8 ++++--- entity_embed/indexes.py | 43 ++++-------------------------------- 2 files changed, 9 insertions(+), 42 deletions(-) diff --git a/entity_embed/entity_embed.py b/entity_embed/entity_embed.py index adcbf44..af1daf4 100644 --- a/entity_embed/entity_embed.py +++ b/entity_embed/entity_embed.py @@ -167,6 +167,7 @@ def fit( tb_save_dir=None, tb_name=None, use_gpu=True, + accelerator=None, ): if early_stop_mode is None: if "pair_entity_ratio_at" in early_stop_monitor: @@ -199,7 +200,8 @@ def fit( } if use_gpu: trainer_args["gpus"] = 1 - + if accelerator: + trainer_args["accelerator"] = accelerator if tb_name and tb_save_dir: trainer_args["logger"] = TensorBoardLogger( tb_save_dir, @@ -211,9 +213,9 @@ def fit( "TensorBoardLogger or omit both to disable it" ) trainer = pl.Trainer(**trainer_args) - print("Trainer done") + trainer.fit(self, datamodule) - print("Model fit") + logger.info( "Loading the best validation model from " f"{trainer.checkpoint_callback.best_model_path}..." 
diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index 7045372..32a0f81 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -20,14 +20,10 @@ def __init__(self, embedding_size): def insert_vector_dict(self, vector_dict): for vector in vector_dict.values(): - # self.approx_knn_index.add_data(vector) - - # print(vector.dtype) - # print(vector.shape) - # print(repr(vector)) - vector = vector.reshape(1, 100) + vector = vector.reshape(1, len(vector)) # vector = faiss.normalize_L2(vector) self.approx_knn_index.add(vector) + # self.approx_knn_index.add_data(vector) self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) def build( @@ -39,7 +35,6 @@ def build( # actual_index_build_kwargs = build_index_build_kwargs(index_build_kwargs) # self.approx_knn_index.build(**actual_index_build_kwargs) - print(self.approx_knn_index.ntotal) self.is_built = True # faiss.write_index(self.approx_knn_index, "vector.index") @@ -50,55 +45,25 @@ def search_pairs(self, k, sim_threshold, index_search_kwargs=None): raise ValueError(f"sim_threshold={sim_threshold} must be <= 1 and >= 0") logger.debug("Searching on approx_knn_index...") - print(sim_threshold) distance_threshold = 1 - sim_threshold - print(distance_threshold) index_search_kwargs = build_index_search_kwargs(index_search_kwargs) found_pair_set = set() - item_ids = self.vector_idx_to_id # .keys() - # print(item_ids) + item_ids = self.vector_idx_to_id for i in item_ids: vector = self.approx_knn_index.reconstruct(i).reshape(1, 100) - # print(vector.shape) - # print(i) similarities, neighbours = self.approx_knn_index.search(vector, k=k) left_id = self.vector_idx_to_id[i] - # print(similarities[0]) for similarity, j in zip(similarities[0], neighbours[0]): - # print(j) if i != j and similarity >= sim_threshold: right_id = self.vector_idx_to_id[j] # must use sorted to always have smaller id on left of pair tuple pair = tuple(sorted([left_id, right_id])) found_pair_set.add(pair) - print(f"found_pair_set: 
{len(found_pair_set)}") - logger.debug("Search on approx_knn_index done, building found_pair_set now...") - - if False: - neighbor_and_distance_list_of_list = self.approx_knn_index.batch_search_by_ids( - item_ids=self.vector_idx_to_id.keys(), - k=k, - include_distances=True, - **index_search_kwargs, - ) - - logger.debug("Search on approx_knn_index done, building found_pair_set now...") - - found_pair_set = set() - for i, neighbor_distance_list in enumerate(neighbor_and_distance_list_of_list): - left_id = self.vector_idx_to_id[i] - for j, distance in neighbor_distance_list: - if i != j and distance <= distance_threshold: - right_id = self.vector_idx_to_id[j] - # must use sorted to always have smaller id on left of pair tuple - pair = tuple(sorted([left_id, right_id])) - found_pair_set.add(pair) - logger.debug( - f"Building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." + f"Search on approx_knn_index and building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." 
) return found_pair_set From 2bc4801c28593cea014079f92954a9d6942ec5a1 Mon Sep 17 00:00:00 2001 From: Hannah Date: Thu, 27 Oct 2022 17:00:24 +0100 Subject: [PATCH 07/17] Add l2 norm --- entity_embed/entity_embed.py | 4 ---- entity_embed/indexes.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/entity_embed/entity_embed.py b/entity_embed/entity_embed.py index af1daf4..0d95e0e 100644 --- a/entity_embed/entity_embed.py +++ b/entity_embed/entity_embed.py @@ -19,10 +19,6 @@ logger = logging.getLogger(__name__) -def hannah_test(n): - return n * 2 - - class _BaseEmbed(pl.LightningModule): def __init__( self, diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index 32a0f81..63b1b4f 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -21,7 +21,7 @@ def __init__(self, embedding_size): def insert_vector_dict(self, vector_dict): for vector in vector_dict.values(): vector = vector.reshape(1, len(vector)) - # vector = faiss.normalize_L2(vector) + faiss.normalize_L2(vector) self.approx_knn_index.add(vector) # self.approx_knn_index.add_data(vector) self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) From c70237eb9352d757a5af7a055ef7b567adb248a2 Mon Sep 17 00:00:00 2001 From: Hannah Date: Tue, 1 Nov 2022 11:36:14 +0000 Subject: [PATCH 08/17] Add requirements --- hannah_requirements.txt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 hannah_requirements.txt diff --git a/hannah_requirements.txt b/hannah_requirements.txt new file mode 100644 index 0000000..8437c66 --- /dev/null +++ b/hannah_requirements.txt @@ -0,0 +1,15 @@ +click==8.0.4 +faiss==1.5.3 +mock==4.0.3 +more_itertools==9.0.0 +numpy==1.21.5 +ordered_set==4.1.0 +pytest==7.1.1 +pytorch_lightning==1.7.7 +pytorch_metric_learning==1.6.2 +regex==2022.3.15 +setuptools==61.2.0 +sphinx_rtd_theme==1.0.0 +torch==1.13.0 +torchtext==0.14.0 +tqdm==4.64.0 From 23b869c2290c8ff7efad84d6c5967279c495bd7e Mon Sep 17 00:00:00 2001 From: Hannah Date: Tue, 1 Nov 
2022 12:04:09 +0000 Subject: [PATCH 09/17] Add print statements throughout for device debugging --- entity_embed/cli.py | 3 ++- entity_embed/data_utils/field_config_parser.py | 10 ++++++++-- entity_embed/entity_embed.py | 7 ++++++- entity_embed/models.py | 1 + requirements.txt | 4 +++- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/entity_embed/cli.py b/entity_embed/cli.py index 2aebb02..b2bff4e 100644 --- a/entity_embed/cli.py +++ b/entity_embed/cli.py @@ -369,7 +369,8 @@ def _load_model(kwargs): model = model_cls.load_from_checkpoint(kwargs["model_save_filepath"], datamodule=None) if kwargs["use_gpu"]: - model = model.to(torch.device("cuda")) + # model = model.to(torch.device("cuda")) + model = model.to(torch.device("mps")) else: model = model.to(torch.device("cpu")) return model diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index 9aa7284..2ca3b35 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -2,6 +2,7 @@ import logging from importlib import import_module +import torch from torch import Tensor, nn from torchtext.vocab import Vocab, Vectors, FastText from torchtext.vocab import vocab as factory_vocab @@ -107,8 +108,10 @@ def _parse_field_config(cls, field, field_config, record_list): # create vector tensor using tokens in vocab, order important vectors = [vectors] - tot_dim = sum(v.dim for v in vectors) # 100 - vector_tensor = Tensor(len(vocab), tot_dim) + # device = torch.device("mps") + tot_dim = sum(v.dim for v in vectors) + # vector_tensor = torch.zeros(len(vocab), tot_dim) + vector_tensor = torch.Tensor(len(vocab), tot_dim) for i, token in enumerate(vocab.get_itos()): start_dim = 0 @@ -119,6 +122,9 @@ def _parse_field_config(cls, field, field_config, record_list): assert start_dim == tot_dim print(f"Vector tensor shape: {vector_tensor.shape}") + print(f"Vector tensor type: {vector_tensor.shape}") + print(f"Vector 
tensor type: {vector_tensor.device}") + print(f"Vector tensor type: {vector_tensor.dtype}") assert len(vector_tensor) == len(vocab) print(nn.Embedding.from_pretrained(vector_tensor)) # check embedding works diff --git a/entity_embed/entity_embed.py b/entity_embed/entity_embed.py index 0d95e0e..1c9cabd 100644 --- a/entity_embed/entity_embed.py +++ b/entity_embed/entity_embed.py @@ -73,11 +73,13 @@ def __init__( self.sim_threshold_list = sim_threshold_list self.index_build_kwargs = index_build_kwargs self.index_search_kwargs = index_search_kwargs + self._dev = "mps" + print(self.device) + print(self._dev) def forward(self, tensor_dict, sequence_length_dict, return_field_embeddings=False): tensor_dict = utils.tensor_dict_to_device(tensor_dict, device=self.device) sequence_length_dict = utils.tensor_dict_to_device(sequence_length_dict, device=self.device) - return self.blocker_net(tensor_dict, sequence_length_dict, return_field_embeddings) def _warn_if_empty_indices_tuple(self, indices_tuple, batch_idx): @@ -194,6 +196,7 @@ def fit( "callbacks": [early_stop_callback, checkpoint_callback], "reload_dataloaders_every_n_epochs": 10, # for shuffling ClusterDataset every epoch } + print(self.device) if use_gpu: trainer_args["gpus"] = 1 if accelerator: @@ -217,6 +220,8 @@ def fit( f"{trainer.checkpoint_callback.best_model_path}..." ) self.blocker_net = None + print(self.device) + print(self._dev) best_model = self.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) best_model = best_model.to(self.device) self.blocker_net = best_model.blocker_net diff --git a/entity_embed/models.py b/entity_embed/models.py index e78e735..e69b5cb 100644 --- a/entity_embed/models.py +++ b/entity_embed/models.py @@ -151,6 +151,7 @@ def forward(self, x, sequence_lengths, **kwargs): x_list = x.unbind(dim=1) x_list = [self.embed_net(x) for x in x_list] x = torch.stack(x_list, dim=1) + print(x.device) # Compute a mask for the attention on the padded sequences # See e.g. 
https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5 diff --git a/requirements.txt b/requirements.txt index e25e7b2..d17f5b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ click==7.1.2,<8.0 more-itertools>=8.6.0,<9.0 -# n2>=0.1.7,<1.2 numpy>=1.19.0 ordered-set>=4.0.2 pytorch_lightning>=1.1.6 @@ -10,3 +9,6 @@ torch>=1.7.1 torchtext>=0.8 torchvision>=0.8.2 tqdm>=4.53.0 + +# conda install grpcio +# pip install faiss-cpu \ No newline at end of file From 6b967b6938fe0d71f46c21f90201bc7f85a2b3d9 Mon Sep 17 00:00:00 2001 From: Benj Pettit Date: Tue, 1 Nov 2022 13:52:15 +0000 Subject: [PATCH 10/17] feat: L2 normalization of vectors to produce cosine-similarity ANN search. torch package upgrades. added FAISS dependency. --- entity_embed/indexes.py | 11 +++++------ requirements.txt | 13 ++++++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index 32a0f81..bbdcc62 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -1,5 +1,6 @@ import logging import faiss +import numpy as np # from n2 import HnswIndex @@ -19,11 +20,10 @@ def __init__(self, embedding_size): print(self.approx_knn_index.is_trained) def insert_vector_dict(self, vector_dict): - for vector in vector_dict.values(): - vector = vector.reshape(1, len(vector)) - # vector = faiss.normalize_L2(vector) - self.approx_knn_index.add(vector) - # self.approx_knn_index.add_data(vector) + vector_array = np.array(list(vector_dict.values())) + l2_norm = np.linalg.norm(vector_array, ord=2, axis=1).reshape(vector_array.shape[0], 1) + vector_array_normalized = vector_array / l2_norm + self.approx_knn_index.add(vector_array_normalized) self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) def build( @@ -68,7 +68,6 @@ def search_pairs(self, k, sim_threshold, index_search_kwargs=None): return found_pair_set - # class ANNEntityIndex: # def __init__(self, embedding_size): # self.approx_knn_index = 
HnswIndex(dimension=embedding_size, metric="angular") diff --git a/requirements.txt b/requirements.txt index e25e7b2..2193f53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,13 @@ more-itertools>=8.6.0,<9.0 # n2>=0.1.7,<1.2 numpy>=1.19.0 ordered-set>=4.0.2 -pytorch_lightning>=1.1.6 -pytorch-metric-learning>=0.9.98 regex>=2020.11.13 -torch>=1.7.1 -torchtext>=0.8 -torchvision>=0.8.2 +pytorch-lightning==1.7.7 +pytorch-metric-learning>=0.9.99 +torch==1.12.1 +torchmetrics>=0.10.1 +torchtext==0.13.1 +torchvision>=0.13.1 tqdm>=4.53.0 +faiss_cpu==1.7.2 + From 18c7aae0906f8eecc0c17f83dfae0b3f7b4c235d Mon Sep 17 00:00:00 2001 From: Hannah Date: Tue, 1 Nov 2022 18:32:14 +0000 Subject: [PATCH 11/17] Add default token to vocab: --- entity_embed/data_utils/field_config_parser.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index 2ca3b35..fdc204f 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -104,7 +104,13 @@ def _parse_field_config(cls, field, field_config, record_list): else: vocab.load_vectors(vocab_type) # won't work - vocab = factory_vocab(vocab_counter) + # adding token + unk_token = "" + vocab = factory_vocab(vocab_counter, specials=[unk_token]) + # print(vocab[""]) # prints 0 + # make default index same as index of unk_token + vocab.set_default_index(vocab[unk_token]) + # print(vocab["probably out of vocab"]) # prints 0 # create vector tensor using tokens in vocab, order important vectors = [vectors] From 83715a1ba0319eaad8e4ae08837c1cc0b8010974 Mon Sep 17 00:00:00 2001 From: Hannah Date: Wed, 2 Nov 2022 13:05:42 +0000 Subject: [PATCH 12/17] Add truncation --- entity_embed/data_utils/numericalizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/entity_embed/data_utils/numericalizer.py b/entity_embed/data_utils/numericalizer.py index 13f48f4..a4e3018 100644 
--- a/entity_embed/data_utils/numericalizer.py +++ b/entity_embed/data_utils/numericalizer.py @@ -103,9 +103,11 @@ def build_tensor(self, val): # with characters as rows and positions as columns. # This is the shape expected by StringEmbedCNN. ord_encoded_val = self._ord_encode(val) + ord_encoded_val = ord_encoded_val[: self.max_str_len] # truncate to max_str_len encoded_arr = np.zeros((len(self.alphabet), self.max_str_len), dtype=np.float32) if len(ord_encoded_val) > 0: encoded_arr[ord_encoded_val, range(len(ord_encoded_val))] = 1.0 + t = torch.from_numpy(encoded_arr) return t, len(val) From 103c8070862b4c01115fd44bdbbf09bfbf200621 Mon Sep 17 00:00:00 2001 From: Benjamin Pettit Date: Thu, 3 Nov 2022 09:45:14 +0000 Subject: [PATCH 13/17] feat(evaluation): model-agnostic embedding evaluator --- entity_embed/evaluation.py | 23 +++++++++++++++++++++++ entity_embed/indexes.py | 13 ++++++------- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/entity_embed/evaluation.py b/entity_embed/evaluation.py index 87ed12d..34a0f4c 100755 --- a/entity_embed/evaluation.py +++ b/entity_embed/evaluation.py @@ -1,5 +1,8 @@ import csv import json +from .indexes import ANNEntityIndex +from .data_utils import utils +import pandas as pd def pair_entity_ratio(found_pair_set_len, entity_count): @@ -53,3 +56,23 @@ def evaluate_output_json( f1_score(precision, recall), pair_entity_ratio(len(found_pair_set), record_count), ) + + +class EmbeddingEvaluator: + def __init__(self, record_dict, vector_dict, cluster_field='cluster_id'): + self.record_dict = record_dict + self.cluster_field = cluster_field + embedding_size = len(next(iter(vector_dict.values()))) + self.ann_index = ANNEntityIndex(embedding_size) + self.ann_index.insert_vector_dict(vector_dict) + self.ann_index.build() + + def evaluate(self, k, sim_thresholds): + cluster_dict = utils.record_dict_to_cluster_dict(self.record_dict, self.cluster_field) + pos_pair_set = utils.cluster_dict_to_id_pairs(cluster_dict) + results = 
[] + for sim_threshold in sim_thresholds: + found_pair_set = self.ann_index.search_pairs(k, sim_threshold) + precision, recall = precision_and_recall(found_pair_set, pos_pair_set) + results.append((sim_threshold, precision, recall, f1_score(precision, recall))) + return pd.DataFrame(results, columns=['threshold', 'precision', 'recall', 'f1_score']) diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index bbdcc62..86b9752 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -16,14 +16,15 @@ def __init__(self, embedding_size): ) # self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") self.vector_idx_to_id = None + self.normalized_vector_array = None self.is_built = False - print(self.approx_knn_index.is_trained) + self.embedding_size = embedding_size def insert_vector_dict(self, vector_dict): vector_array = np.array(list(vector_dict.values())) l2_norm = np.linalg.norm(vector_array, ord=2, axis=1).reshape(vector_array.shape[0], 1) - vector_array_normalized = vector_array / l2_norm - self.approx_knn_index.add(vector_array_normalized) + self.normalized_vector_array = vector_array / l2_norm + self.approx_knn_index.add(self.normalized_vector_array) self.vector_idx_to_id = dict(enumerate(vector_dict.keys())) def build( @@ -50,11 +51,9 @@ def search_pairs(self, k, sim_threshold, index_search_kwargs=None): index_search_kwargs = build_index_search_kwargs(index_search_kwargs) found_pair_set = set() - item_ids = self.vector_idx_to_id - for i in item_ids: - vector = self.approx_knn_index.reconstruct(i).reshape(1, 100) + for i, left_id in self.vector_idx_to_id.items(): + vector = self.normalized_vector_array[[i], :] similarities, neighbours = self.approx_knn_index.search(vector, k=k) - left_id = self.vector_idx_to_id[i] for similarity, j in zip(similarities[0], neighbours[0]): if i != j and similarity >= sim_threshold: right_id = self.vector_idx_to_id[j] From 5621e974ca7ab75ec5bd7c0dec0de8bc9158138e Mon Sep 17 00:00:00 2001 
From: Benjamin Pettit Date: Fri, 4 Nov 2022 13:27:55 +0000 Subject: [PATCH 14/17] feat(evaluation): evaluator takes an optional list of query_ids --- entity_embed/evaluation.py | 29 ++++++++++++++++++++++++----- entity_embed/indexes.py | 19 ++++++++++--------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/entity_embed/evaluation.py b/entity_embed/evaluation.py index 34a0f4c..4f89457 100755 --- a/entity_embed/evaluation.py +++ b/entity_embed/evaluation.py @@ -1,5 +1,6 @@ import csv import json +import random from .indexes import ANNEntityIndex from .data_utils import utils import pandas as pd @@ -66,13 +67,31 @@ def __init__(self, record_dict, vector_dict, cluster_field='cluster_id'): self.ann_index = ANNEntityIndex(embedding_size) self.ann_index.insert_vector_dict(vector_dict) self.ann_index.build() + self.cluster_dict = utils.record_dict_to_cluster_dict(self.record_dict, self.cluster_field) + self.pos_pair_set = utils.cluster_dict_to_id_pairs(self.cluster_dict) + + def evaluate(self, k, sim_thresholds, query_ids=None): + """ + params: + k: int: number of nearest neighbours to retrieve + sim_thresholds: list of floats in the range [0,1]: + query_ids: list or set of ids that must be keys in self.vector_dict and self.record_dict. Indicates + which ids to find pairs for. 
If None, use all record ids as query ids + + returns: pandas DataFrame of results, with one row for each threshold + """ + if query_ids is None: + print(f"Using all {len(self.record_dict)} records to query for neighbours") + pos_pair_subset = self.pos_pair_set + else: + query_ids = set(query_ids) + print(f"Using subset of {len(query_ids)} query IDs") + pos_pair_subset = {pair for pair in self.pos_pair_set + if pair[0] in query_ids or pair[1] in query_ids} - def evaluate(self, k, sim_thresholds): - cluster_dict = utils.record_dict_to_cluster_dict(self.record_dict, self.cluster_field) - pos_pair_set = utils.cluster_dict_to_id_pairs(cluster_dict) results = [] for sim_threshold in sim_thresholds: - found_pair_set = self.ann_index.search_pairs(k, sim_threshold) - precision, recall = precision_and_recall(found_pair_set, pos_pair_set) + found_pair_set = self.ann_index.search_pairs(k, sim_threshold, query_id_subset=query_ids) + precision, recall = precision_and_recall(found_pair_set, pos_pair_subset) results.append((sim_threshold, precision, recall, f1_score(precision, recall))) return pd.DataFrame(results, columns=['threshold', 'precision', 'recall', 'f1_score']) diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index 86b9752..11d5ef0 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -39,7 +39,7 @@ def build( self.is_built = True # faiss.write_index(self.approx_knn_index, "vector.index") - def search_pairs(self, k, sim_threshold, index_search_kwargs=None): + def search_pairs(self, k, sim_threshold, index_search_kwargs=None, query_id_subset=None): if not self.is_built: raise ValueError("Please call build first") if sim_threshold > 1 or sim_threshold < 0: @@ -52,14 +52,15 @@ def search_pairs(self, k, sim_threshold, index_search_kwargs=None): found_pair_set = set() for i, left_id in self.vector_idx_to_id.items(): - vector = self.normalized_vector_array[[i], :] - similarities, neighbours = self.approx_knn_index.search(vector, k=k) - for 
similarity, j in zip(similarities[0], neighbours[0]): - if i != j and similarity >= sim_threshold: - right_id = self.vector_idx_to_id[j] - # must use sorted to always have smaller id on left of pair tuple - pair = tuple(sorted([left_id, right_id])) - found_pair_set.add(pair) + if query_id_subset is None or left_id in query_id_subset: + vector = self.normalized_vector_array[[i], :] + similarities, neighbours = self.approx_knn_index.search(vector, k=k) + for similarity, j in zip(similarities[0], neighbours[0]): + if i != j and similarity >= sim_threshold: + right_id = self.vector_idx_to_id[j] + # must use sorted to always have smaller id on left of pair tuple + pair = tuple(sorted([left_id, right_id])) + found_pair_set.add(pair) logger.debug( f"Search on approx_knn_index and building found_pair_set done. Found len(found_pair_set)={len(found_pair_set)} pairs." From 6ab22dee9e669489e142e56e92f53a337851e509 Mon Sep 17 00:00:00 2001 From: Hannah Date: Fri, 4 Nov 2022 15:35:42 +0000 Subject: [PATCH 15/17] feat(evaluation): add missing pair sets --- entity_embed/evaluation.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/entity_embed/evaluation.py b/entity_embed/evaluation.py index 4f89457..a6a60cb 100755 --- a/entity_embed/evaluation.py +++ b/entity_embed/evaluation.py @@ -60,7 +60,7 @@ def evaluate_output_json( class EmbeddingEvaluator: - def __init__(self, record_dict, vector_dict, cluster_field='cluster_id'): + def __init__(self, record_dict, vector_dict, cluster_field="cluster_id"): self.record_dict = record_dict self.cluster_field = cluster_field embedding_size = len(next(iter(vector_dict.values()))) @@ -70,7 +70,7 @@ def __init__(self, record_dict, vector_dict, cluster_field='cluster_id'): self.cluster_dict = utils.record_dict_to_cluster_dict(self.record_dict, self.cluster_field) self.pos_pair_set = utils.cluster_dict_to_id_pairs(self.cluster_dict) - def evaluate(self, k, sim_thresholds, query_ids=None): + def 
evaluate(self, k, sim_thresholds, query_ids=None, get_missing_pair_set=False): """ params: k: int: number of nearest neighbours to retrieve @@ -86,12 +86,24 @@ def evaluate(self, k, sim_thresholds, query_ids=None): else: query_ids = set(query_ids) print(f"Using subset of {len(query_ids)} query IDs") - pos_pair_subset = {pair for pair in self.pos_pair_set - if pair[0] in query_ids or pair[1] in query_ids} - + pos_pair_subset = { + pair for pair in self.pos_pair_set if pair[0] in query_ids or pair[1] in query_ids + } results = [] for sim_threshold in sim_thresholds: - found_pair_set = self.ann_index.search_pairs(k, sim_threshold, query_id_subset=query_ids) + found_pair_set = self.ann_index.search_pairs( + k, sim_threshold, query_id_subset=query_ids + ) precision, recall = precision_and_recall(found_pair_set, pos_pair_subset) results.append((sim_threshold, precision, recall, f1_score(precision, recall))) - return pd.DataFrame(results, columns=['threshold', 'precision', 'recall', 'f1_score']) + if get_missing_pair_set & (sim_threshold == min(sim_thresholds)): + self.missing_pair_set = pos_pair_subset - found_pair_set + id_to_name_map = {k: v["merchant_name"] for k, v in self.record_dict.items()} + self.missing_pair_name_set = set( + map( + lambda x: (id_to_name_map[x[0]], id_to_name_map[x[1]]), + self.missing_pair_set, + ) + ) + + return pd.DataFrame(results, columns=["threshold", "precision", "recall", "f1_score"]) From 35e386f31485abe845a8effa376b47b5bdac519d Mon Sep 17 00:00:00 2001 From: Hannah Date: Mon, 7 Nov 2022 17:19:49 +0000 Subject: [PATCH 16/17] feat: add logging --- entity_embed/evaluation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/entity_embed/evaluation.py b/entity_embed/evaluation.py index a6a60cb..ae5c189 100755 --- a/entity_embed/evaluation.py +++ b/entity_embed/evaluation.py @@ -4,7 +4,9 @@ from .indexes import ANNEntityIndex from .data_utils import utils import pandas as pd +import logging +logger = 
logging.getLogger(__name__) def pair_entity_ratio(found_pair_set_len, entity_count): return found_pair_set_len / entity_count @@ -64,10 +66,13 @@ def __init__(self, record_dict, vector_dict, cluster_field="cluster_id"): self.record_dict = record_dict self.cluster_field = cluster_field embedding_size = len(next(iter(vector_dict.values()))) + logging.info("Building index...") self.ann_index = ANNEntityIndex(embedding_size) self.ann_index.insert_vector_dict(vector_dict) self.ann_index.build() + logging.info("Index built! Getting cluster dict...") self.cluster_dict = utils.record_dict_to_cluster_dict(self.record_dict, self.cluster_field) + logging.info("Getting positive pairs...") self.pos_pair_set = utils.cluster_dict_to_id_pairs(self.cluster_dict) def evaluate(self, k, sim_thresholds, query_ids=None, get_missing_pair_set=False): @@ -81,11 +86,11 @@ def evaluate(self, k, sim_thresholds, query_ids=None, get_missing_pair_set=False returns: pandas DataFrame of results, with one row for each threshold """ if query_ids is None: - print(f"Using all {len(self.record_dict)} records to query for neighbours") + logging.info(f"Using all {len(self.record_dict)} records to query for neighbours") pos_pair_subset = self.pos_pair_set else: query_ids = set(query_ids) - print(f"Using subset of {len(query_ids)} query IDs") + logging.info(f"Using subset of {len(query_ids)} query IDs") pos_pair_subset = { pair for pair in self.pos_pair_set if pair[0] in query_ids or pair[1] in query_ids } From 3e8e0252fd1ba23994102b2f23c02d3ba55525d4 Mon Sep 17 00:00:00 2001 From: Hannah Date: Fri, 18 Nov 2022 11:35:39 +0000 Subject: [PATCH 17/17] feat(preprocessor): add preprocessor feature --- entity_embed/__init__.py | 7 +++- .../data_utils/field_config_parser.py | 15 ++++--- entity_embed/data_utils/numericalizer.py | 41 +++++++++++++++++++ entity_embed/early_stopping.py | 2 +- entity_embed/entity_embed.py | 26 +++++++----- entity_embed/evaluation.py | 12 ++++-- entity_embed/indexes.py | 5 +++ 
requirements.txt | 1 - 8 files changed, 86 insertions(+), 23 deletions(-) diff --git a/entity_embed/__init__.py b/entity_embed/__init__.py index 9ce1ae6..4a5f21d 100644 --- a/entity_embed/__init__.py +++ b/entity_embed/__init__.py @@ -6,7 +6,12 @@ from .data_modules import * # noqa: F401, F403 from .data_utils.field_config_parser import FieldConfigDictParser # noqa: F401 -from .data_utils.numericalizer import default_tokenizer # noqa: F401 +from .data_utils.numericalizer import ( + default_tokenizer, + remove_space_digit_punc, + remove_places, + default_pre_processor, +) # noqa: F401 from .entity_embed import * # noqa: F401, F403 from .indexes import * # noqa: F401, F403 diff --git a/entity_embed/data_utils/field_config_parser.py b/entity_embed/data_utils/field_config_parser.py index fdc204f..b7a6b54 100644 --- a/entity_embed/data_utils/field_config_parser.py +++ b/entity_embed/data_utils/field_config_parser.py @@ -66,6 +66,12 @@ def _parse_field_config(cls, field, field_config, record_list): tokenizer = _import_function( field_config.get("tokenizer", "entity_embed.default_tokenizer") ) + pre_processor = _import_function( + field_config.get("pre_processor", "entity_embed.default_pre_processor") + ) + multi_pre_processor = _import_function( + field_config.get("multi_pre_processor", "entity_embed.default_pre_processor") + ) alphabet = field_config.get("alphabet", DEFAULT_ALPHABET) max_str_len = field_config.get("max_str_len") vocab = None @@ -127,14 +133,9 @@ def _parse_field_config(cls, field, field_config, record_list): start_dim = end_dim assert start_dim == tot_dim - print(f"Vector tensor shape: {vector_tensor.shape}") - print(f"Vector tensor type: {vector_tensor.shape}") - print(f"Vector tensor type: {vector_tensor.device}") - print(f"Vector tensor type: {vector_tensor.dtype}") + logger.info(f"Vector tensor shape: {vector_tensor.shape}") assert len(vector_tensor) == len(vocab) - print(nn.Embedding.from_pretrained(vector_tensor)) # check embedding works - # 
Compute max_str_len if necessary if field_type in (FieldType.STRING, FieldType.MULTITOKEN) and (max_str_len is None): logger.info(f"For field={field}, computing actual max_str_len") @@ -165,6 +166,8 @@ def _parse_field_config(cls, field, field_config, record_list): key=key, field_type=field_type, tokenizer=tokenizer, + pre_processor=pre_processor, + multi_pre_processor=multi_pre_processor, alphabet=alphabet, max_str_len=max_str_len, vocab=vocab, diff --git a/entity_embed/data_utils/numericalizer.py b/entity_embed/data_utils/numericalizer.py index a4e3018..14246d8 100644 --- a/entity_embed/data_utils/numericalizer.py +++ b/entity_embed/data_utils/numericalizer.py @@ -4,10 +4,12 @@ from enum import Enum from typing import Callable, List +from string import punctuation import numpy as np import regex import torch from torchtext.vocab import Vocab +from flashgeotext.geotext import GeoText, GeoTextConfiguration logger = logging.getLogger(__name__) @@ -42,6 +44,8 @@ class FieldType(Enum): class FieldConfig: key: str field_type: FieldType + pre_processor: Callable[[str], List[str]] + multi_pre_processor: Callable[[str], List[str]] tokenizer: Callable[[str], List[str]] alphabet: List[str] max_str_len: int @@ -79,6 +83,31 @@ def default_tokenizer(val): return tokenizer_re.findall(val) +def remove_space_digit_punc(val): + val = "".join(c for c in val if (not c.isdigit()) and (c not in punctuation)) + return val.replace(" ", "") + + +config = GeoTextConfiguration(**{"case_sensitive": False}) +geotext = GeoText(config) + + +def default_pre_processor(text): + return text + + +def remove_places(text): + places = geotext.extract(text) + found_places = [] + for i, v in places.items(): + for w, x in v.items(): + word = x["found_as"][0] + if word not in ["at", "com", "us", "usa"]: + found_places.append(word) + text = text.replace(word, "") + return text + + class StringNumericalizer: is_multitoken = False @@ -87,6 +116,8 @@ def __init__(self, field, field_config): self.alphabet = 
field_config.alphabet self.max_str_len = field_config.max_str_len self.char_to_ord = {c: i for i, c in enumerate(self.alphabet)} + self.pre_processor = field_config.pre_processor + print(f"Found pre_processor {self.pre_processor} for field {self.field}") def _ord_encode(self, val): ord_encoded = [] @@ -102,6 +133,9 @@ def build_tensor(self, val): # encoded_arr is a one hot encoded bidimensional tensor # with characters as rows and positions as columns. # This is the shape expected by StringEmbedCNN. + # if val != self.pre_processor(val): + # print(f"{val} -> {self.pre_processor(val)} -> {self.pre_processor} -> {self.field}") + val = self.pre_processor(val) ord_encoded_val = self._ord_encode(val) ord_encoded_val = ord_encoded_val[: self.max_str_len] # truncate to max_str_len encoded_arr = np.zeros((len(self.alphabet), self.max_str_len), dtype=np.float32) @@ -131,10 +165,16 @@ class MultitokenNumericalizer: def __init__(self, field, field_config): self.field = field + self.field_type = field_config.field_type + self.multi_pre_processor = field_config.multi_pre_processor self.tokenizer = field_config.tokenizer self.string_numericalizer = StringNumericalizer(field=field, field_config=field_config) + print(f"Found multi_pre_processor {self.multi_pre_processor} for field {self.field}") def build_tensor(self, val): + # if val != self.multi_pre_processor(val): + # print(f"{val} -> {self.multi_pre_processor(val)} -> {self.multi_pre_processor} -> {self.field}") + val = self.multi_pre_processor(val) val_tokens = self.tokenizer(val) t_list = [] for v in val_tokens: @@ -153,6 +193,7 @@ class SemanticMultitokenNumericalizer(MultitokenNumericalizer): def __init__(self, field, field_config): self.field = field self.tokenizer = field_config.tokenizer + self.multi_pre_processor = field_config.multi_pre_processor self.string_numericalizer = SemanticStringNumericalizer( field=field, field_config=field_config ) diff --git a/entity_embed/early_stopping.py 
b/entity_embed/early_stopping.py index 5e99b98..5edd788 100644 --- a/entity_embed/early_stopping.py +++ b/entity_embed/early_stopping.py @@ -38,7 +38,7 @@ def __init__( dirpath=None, filename=None, verbose=False, - save_last=None, + save_last=True, save_top_k=None, save_weights_only=False, period=1, diff --git a/entity_embed/entity_embed.py b/entity_embed/entity_embed.py index 1c9cabd..c613103 100644 --- a/entity_embed/entity_embed.py +++ b/entity_embed/entity_embed.py @@ -74,8 +74,6 @@ def __init__( self.index_build_kwargs = index_build_kwargs self.index_search_kwargs = index_search_kwargs self._dev = "mps" - print(self.device) - print(self._dev) def forward(self, tensor_dict, sequence_length_dict, return_field_embeddings=False): tensor_dict = utils.tensor_dict_to_device(tensor_dict, device=self.device) @@ -154,6 +152,7 @@ def fit( min_epochs=5, max_epochs=100, check_val_every_n_epoch=1, + use_early_stop=False, early_stop_monitor="valid_recall_at_0.3", early_stop_min_delta=0.0, early_stop_patience=20, @@ -161,18 +160,19 @@ def fit( early_stop_verbose=True, model_save_top_k=1, model_save_dir=None, + model_save_filename=None, model_save_verbose=False, tb_save_dir=None, tb_name=None, - use_gpu=True, - accelerator=None, + use_gpu=False, + accelerator="cpu", + ckpt_path=None, ): if early_stop_mode is None: if "pair_entity_ratio_at" in early_stop_monitor: early_stop_mode = "min" else: early_stop_mode = "max" - early_stop_callback = EarlyStoppingMinEpochs( min_epochs=min_epochs, monitor=early_stop_monitor, @@ -181,22 +181,27 @@ def fit( mode=early_stop_mode, verbose=early_stop_verbose, ) + callbacks = [] + if use_early_stop: + callbacks.append(early_stop_callback) + print("Using early stopping callback...") checkpoint_callback = ModelCheckpointMinEpochs( min_epochs=min_epochs, monitor=early_stop_monitor, save_top_k=model_save_top_k, mode=early_stop_mode, dirpath=model_save_dir, + filename=model_save_filename, verbose=model_save_verbose, ) + 
callbacks.append(checkpoint_callback) trainer_args = { "min_epochs": min_epochs, "max_epochs": max_epochs, "check_val_every_n_epoch": check_val_every_n_epoch, - "callbacks": [early_stop_callback, checkpoint_callback], + "callbacks": callbacks, "reload_dataloaders_every_n_epochs": 10, # for shuffling ClusterDataset every epoch } - print(self.device) if use_gpu: trainer_args["gpus"] = 1 if accelerator: @@ -211,17 +216,18 @@ def fit( 'Please provide both "tb_name" and "tb_save_dir" to enable ' "TensorBoardLogger or omit both to disable it" ) + fit_args = {} + if ckpt_path: + fit_args["ckpt_path"] = ckpt_path trainer = pl.Trainer(**trainer_args) - trainer.fit(self, datamodule) + trainer.fit(self, datamodule, **fit_args) logger.info( "Loading the best validation model from " f"{trainer.checkpoint_callback.best_model_path}..." ) self.blocker_net = None - print(self.device) - print(self._dev) best_model = self.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) best_model = best_model.to(self.device) self.blocker_net = best_model.blocker_net diff --git a/entity_embed/evaluation.py b/entity_embed/evaluation.py index ae5c189..4cf767d 100755 --- a/entity_embed/evaluation.py +++ b/entity_embed/evaluation.py @@ -8,6 +8,7 @@ logger = logging.getLogger(__name__) + def pair_entity_ratio(found_pair_set_len, entity_count): return found_pair_set_len / entity_count @@ -70,9 +71,8 @@ def __init__(self, record_dict, vector_dict, cluster_field="cluster_id"): self.ann_index = ANNEntityIndex(embedding_size) self.ann_index.insert_vector_dict(vector_dict) self.ann_index.build() - logging.info("Index built! 
Getting cluster dict...") + logging.info("Index built!") self.cluster_dict = utils.record_dict_to_cluster_dict(self.record_dict, self.cluster_field) - logging.info("Getting positive pairs...") self.pos_pair_set = utils.cluster_dict_to_id_pairs(self.cluster_dict) def evaluate(self, k, sim_thresholds, query_ids=None, get_missing_pair_set=False): @@ -100,7 +100,9 @@ def evaluate(self, k, sim_thresholds, query_ids=None, get_missing_pair_set=False k, sim_threshold, query_id_subset=query_ids ) precision, recall = precision_and_recall(found_pair_set, pos_pair_subset) - results.append((sim_threshold, precision, recall, f1_score(precision, recall))) + results.append( + (sim_threshold, precision, recall, f1_score(precision, recall), len(found_pair_set)) + ) if get_missing_pair_set & (sim_threshold == min(sim_thresholds)): self.missing_pair_set = pos_pair_subset - found_pair_set id_to_name_map = {k: v["merchant_name"] for k, v in self.record_dict.items()} @@ -111,4 +113,6 @@ def evaluate(self, k, sim_thresholds, query_ids=None, get_missing_pair_set=False ) ) - return pd.DataFrame(results, columns=["threshold", "precision", "recall", "f1_score"]) + return pd.DataFrame( + results, columns=["threshold", "precision", "recall", "f1_score", "n_pairs_found"] + ) diff --git a/entity_embed/indexes.py b/entity_embed/indexes.py index 11d5ef0..9363152 100644 --- a/entity_embed/indexes.py +++ b/entity_embed/indexes.py @@ -55,6 +55,10 @@ def search_pairs(self, k, sim_threshold, index_search_kwargs=None, query_id_subs if query_id_subset is None or left_id in query_id_subset: vector = self.normalized_vector_array[[i], :] similarities, neighbours = self.approx_knn_index.search(vector, k=k) + if all(similarities[0] >= sim_threshold) & (sim_threshold > 0.4): + print( + f"Found pair similarities for k = {k} are all higher than threshold {sim_threshold}" + ) for similarity, j in zip(similarities[0], neighbours[0]): if i != j and similarity >= sim_threshold: right_id = self.vector_idx_to_id[j] @@ 
-68,6 +72,7 @@ def search_pairs(self, k, sim_threshold, index_search_kwargs=None, query_id_subs return found_pair_set + # class ANNEntityIndex: # def __init__(self, embedding_size): # self.approx_knn_index = HnswIndex(dimension=embedding_size, metric="angular") diff --git a/requirements.txt b/requirements.txt index a25190d..4384bba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,6 @@ pytorch-metric-learning>=0.9.99 torch==1.12.1 torchmetrics>=0.10.1 torchtext==0.13.1 -torchvision>=0.13.1 tqdm>=4.53.0 # conda install grpcio