Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hugging Face relevance training script #84

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions hugging_face/example/clean_data.py

This file was deleted.

43 changes: 43 additions & 0 deletions hugging_face/example/split_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import csv

Check warning on line 1 in hugging_face/example/split_data.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 Missing docstring in public module Raw Output: ./hugging_face/example/split_data.py:1:1: D100 Missing docstring in public module
import random
import sys


""" This script randomly seperates a csv into a train and test split for use in training.

Check failure on line 6 in hugging_face/example/split_data.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (89 > 79 characters) Raw Output: ./hugging_face/example/split_data.py:6:80: E501 line too long (89 > 79 characters)
The script will filter out rows containing multiple labels and preservve at least one unique label for the test script.

Check failure on line 7 in hugging_face/example/split_data.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (123 > 79 characters) Raw Output: ./hugging_face/example/split_data.py:7:80: E501 line too long (123 > 79 characters)
"""


labels = set()
csv.field_size_limit(sys.maxsize)

with open("labeled-urls-headers_all.csv", newline="") as csvfile:
reader = csv.DictReader(csvfile)
# result = sorted(reader, key=lambda d: int(d["id"]))
with open("train-urls.csv", "w", newline="") as writefile:
writer = csv.writer(writefile)
writer.writerow(["url", "label"])

vd_writer = csv.writer(open("test-urls.csv", "w", newline=""))
vd_writer.writerow(["url", "label"])

for row in reader:
label = row["label"]

if "#" in label:
continue

url = row["url"]

if not url:
continue

rand = random.randint(1, 13)

if label not in labels:
labels.add(label)
writer.writerow([url, row["label"]])
elif rand != 1:
writer.writerow([url, row["label"]])
else:
vd_writer.writerow([url, row["label"]])
28 changes: 28 additions & 0 deletions hugging_face/url_relevance/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Hugging Face URL Relevance Model

This model is trained using website data from a list of potentially relevant URLs.

A "relevant" URL is one that related to criminal justice. A "relevant" website does not necessarily mean it is a "good" data source.

The latest version of the model can be found here: [https://huggingface.co/PDAP/url-relevance](https://huggingface.co/PDAP/url-relevance)

## How to use

The training script requires `Python 3.10` or lower to install the dependencies.

1. `cd` into the root directory of the project
2. Create a virtual environment. In your terminal:
```commandline
python -m venv relevance-environment
source relevance-environment/bin/activate
```
3. Now install the required python libraries:
```commandline
$pip install -r requirements.txt
```
4. Run `python3 hugging_face/url_relevance/huggingface_relevance.py`

## Scripts

- `huggingface_relevance.py` - The training script for the model.
- `clean_data.py` - Cleans up raw website data from the tag collector so that it may be used for effective training.
11 changes: 11 additions & 0 deletions hugging_face/url_relevance/clean-data-example.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
url,url_path,label,html_title,meta_description,root_page_title,http_response,keywords,h1,h2,h3,h4,h5,h6
https://coloradosprings.gov/police-department/article/news/i-25-traffic-safety-deployment-after-stop,police-department/article/news/i-25-traffic-safety-deployment-after-stop,Relevant,I-25 Traffic Safety Deployment- After the Stop | City of Colorado Springs,"",Home Page | City of Colorado Springs,200,"['traffic safety deployment', '25 traffic safety', 'colorado springs traffic', 'springs traffic transportation', 'traffic safety', 'safety traffic', 'transportation engineering traffic', 'keeping roadways safe', 'traffic safety traffic', 'colorado springs drivers']","[""I-25 Traffic Safety Deployment- After the Stop""]","[""Search"", ""Colorado Springs Weekly"", ""GoCOS!"", ""Connect with @CityofCOS""]",[],[],"[""REPORT ONLINE""]",[]
http://www.longbeach.gov/police/press-releases/pursuit-with-traffic-fatality---3rd-street-and-temple-avenue/,police/press-releases/pursuit-with-traffic-fatality---3rd-street-and-temple-avenue/,Relevant,PURSUIT WITH TRAFFIC FATALITY - 3RD STREET AND TEMPLE AVENUE,"",City of Long Beach,200,"['pursuit traffic fatality', 'pursuit stolen vehicle', 'long beach police', 'vehicle victim', 'involved pursuit stolen', 'stolen vehicle victim', 'officers involved pursuit', 'suspect vehicle collided', 'traffic fatality 3rd', 'police department collision']","[""Long Beach Police Department""]",[],[],[],[],[]
http://www.ryepolice.us/logs/police-logs-for-6-3-20-6-9-20,logs/police-logs-for-6-3-20-6-9-20,Relevant,Rye Police Department Police Logs for 6/3/20-6/9/20 - Rye Police Department,"",Rye Police Department Welcome to the Rye Police Department - Rye Police Department,200,"['rye police department', 'rye police', '20 rye police', 'ordinances victim services', 'police department emergency', 'police logs 20', '2020 police logs', 'report request town', 'location address rye', 'department emergency 911']","[""Police Logs for 6/3/20-6/9/20""]","[""Police Logs for 6/3/20-6/9/20""]","[""Navigation"", ""Follow Us"", ""Facebook"", ""Email Updates"", ""Storm Preparedness Guide""]",[],[],[]
http://www.ryepolice.us/logs/police-logs-for-11216-11816,logs/police-logs-for-11216-11816,Relevant,Rye Police Department Police Logs for 11/2/16-11/8/16 - Rye Police Department,"",Rye Police Department Welcome to the Rye Police Department - Rye Police Department,200,"['rye police department', 'rye public safety', 'rye police', '16 rye police', 'ordinances victim services', 'police department emergency', 'report request town', 'department emergency 911', 'police department', 'department police']","[""Police Logs for 11/2/16-11/8/16""]","[""Police Logs for 11/2/16-11/8/16""]","[""Navigation"", ""Follow Us"", ""Facebook"", ""Email Updates"", ""Storm Preparedness Guide""]",[],[],[]
https://delcopa.gov/sheriff/pdf/atf_letter.pdf,sheriff/pdf/atf_letter.pdf,Relevant,"","","Delaware County, Pennsylvania",200,"","","","","","",""
https://www.mass.gov/event/jlmc-17-6105-watertown-police-association-3a-hearing-open-meeting-notice-05-09-2018-2018-05-09t100000-0400-2018-05-09t170000-0400,event/jlmc-17-6105-watertown-police-association-3a-hearing-open-meeting-notice-05-09-2018-2018-05-09t100000-0400-2018-05-09t170000-0400,Relevant,JLMC-17-6105 Watertown Police Association 3(A) Hearing Open Meeting Notice 05-09-2018 | Mass.gov,JLMC-17-6105 Watertown Police Association 3(A) Hearing Open Meeting Notice 05-09-2018,Mass.gov,200,"['government organization massachusetts', 'gov mass', 'mass gov', 'content mass gov', 'mass gov mass', 'massachusetts mass gov', 'commonwealth massachusetts', 'organization massachusetts', 'gov mass gov', 'massachusetts know official']","[""JLMC-17-6105 Watertown Police Association 3(A) Hearing Open Meeting Notice 05-09-2018""]","[""Address"", ""Overview of JLMC-17-6105 Watertown Police Association 3(A) Hearing Open Meeting Notice 05-09-2018"", ""Additional Resources for JLMC-17-6105 Watertown Police Association 3(A) Hearing Open Meeting Notice 05-09-2018"", ""Help Us Improve Mass.gov with your feedback""]",[],[],[],[]
https://ridgelandsc.gov/police-department/daily-arrest-reports-may,police-department/daily-arrest-reports-may,Relevant,Town of Ridgeland,Town of Ridgeland,Town of Ridgeland,200,"['maps community ridgeland', 'ridgeland home government', 'services town ridgeland', 'town ridgeland', 'town ridgeland website', 'town ridgeland farmers', 'community ridgeland', 'town ridgeland home', 'community ridgeland battle', 'ridgeland website']","[""Police Department""]","[""Daily Arrest Reports - May""]",[],[],[],[]
https://delcopa.gov/planning/pdf/demodata/minoritypopulation2020map.pdf,planning/pdf/demodata/minoritypopulation2020map.pdf,Irrelevant,"","","Delaware County, Pennsylvania",200,"","","","","","",""
https://www.mass.gov/doc/christine-kennedy-v-city-of-chicopee-school-dept/download,doc/christine-kennedy-v-city-of-chicopee-school-dept/download,Relevant,"","",Mass.gov,200,"","","","","","",""
https://www.providenceri.gov/hr/wellness/cop-manage-anxiety-9-11/,hr/wellness/cop-manage-anxiety-9-11/,Irrelevant,City of Providence CoP Manage Anxiety 9.11 - City of Providence,"",City of Providence Home - City of Providence,200,"['providence cop manage', 'city providence cop', 'providence cop', 'city providence', 'providence city hall', 'providence city', '11 city providence', 'providence providence', 'providence rhode', 'street providence']","[""CoP Manage Anxiety 9.11""]","[""Share this story"", ""Providence City Hall"", ""Follow Us on Social Media:""]","[""City Of Providence"", ""Mayor Brett Smiley"", ""SIGN UP FOR OUR WEEKLY E-NEWS | Haga clic aquí para español"", ""Lists *""]",[],[],[]
58 changes: 58 additions & 0 deletions hugging_face/url_relevance/clean_data.py
EvilDrPurple marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import ast

Check warning on line 1 in hugging_face/url_relevance/clean_data.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 Missing docstring in public module Raw Output: ./hugging_face/url_relevance/clean_data.py:1:1: D100 Missing docstring in public module
import csv
import sys
import os


""" This script cleans up raw website data from the tag collector so that it may be used for effective training.

Check failure on line 7 in hugging_face/url_relevance/clean_data.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (112 > 79 characters) Raw Output: ./hugging_face/url_relevance/clean_data.py:7:80: E501 line too long (112 > 79 characters)
It primarily merges list of strings into a single string, removing brackets and quotes from the string.

Check failure on line 8 in hugging_face/url_relevance/clean_data.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (107 > 79 characters) Raw Output: ./hugging_face/url_relevance/clean_data.py:8:80: E501 line too long (107 > 79 characters)
"""


csv.field_size_limit(sys.maxsize)
FILE = "clean-data-example.csv"

with open(FILE, newline="") as readFile, open("new.csv", "w", newline="") as writeFile:

Check failure on line 15 in hugging_face/url_relevance/clean_data.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (87 > 79 characters) Raw Output: ./hugging_face/url_relevance/clean_data.py:15:80: E501 line too long (87 > 79 characters)
reader = csv.DictReader(readFile)
fieldnames = [
"url",
"url_path",
"label",
"html_title",
"meta_description",
"root_page_title",
"http_response",
"keywords",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"div_text",
]
writer = csv.DictWriter(writeFile, fieldnames=fieldnames)
writer.writeheader()

for row in reader:
write_row = row

for key, value in write_row.items():
try:
val_list = ast.literal_eval(value)
except (SyntaxError, ValueError):
continue

# if key == "keywords":
# l = l[0:2]

try:
value = " ".join(val_list)
except TypeError:
continue

write_row.update({key: value})

writer.writerow(write_row)

os.rename("new.csv", FILE)
118 changes: 118 additions & 0 deletions hugging_face/url_relevance/huggingface_relevance.py
EvilDrPurple marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from datasets import load_dataset

Check warning on line 1 in hugging_face/url_relevance/huggingface_relevance.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 Missing docstring in public module Raw Output: ./hugging_face/url_relevance/huggingface_relevance.py:1:1: D100 Missing docstring in public module
from datasets import concatenate_datasets
from transformers import TrainingArguments, Trainer
from transformers import AutoTokenizer
from multimodal_transformers.model import AutoModelWithTabular, TabularConfig
from transformers import AutoConfig
from multimodal_transformers.data import load_data
import numpy as np
import pandas as pd
import evaluate


""" This model is trained using website data from a list of potentially relevant URLs.

Check failure on line 13 in hugging_face/url_relevance/huggingface_relevance.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (86 > 79 characters) Raw Output: ./hugging_face/url_relevance/huggingface_relevance.py:13:80: E501 line too long (86 > 79 characters)
A "relevant" URL is one that related to criminal justice. A "relevant" website does not necessarily mean it is a "good" data source.

Check failure on line 14 in hugging_face/url_relevance/huggingface_relevance.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (136 > 79 characters) Raw Output: ./hugging_face/url_relevance/huggingface_relevance.py:14:80: E501 line too long (136 > 79 characters)
The latest version of the model can be found here: https://huggingface.co/PDAP/url-relevance

Check failure on line 15 in hugging_face/url_relevance/huggingface_relevance.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (96 > 79 characters) Raw Output: ./hugging_face/url_relevance/huggingface_relevance.py:15:80: E501 line too long (96 > 79 characters)
"""

MODEL = "distilbert-base-uncased"
DATASET = "PDAP/urls-relevance"
MAX_STEPS = 1000


def str2int(label):

Check warning on line 23 in hugging_face/url_relevance/huggingface_relevance.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 Missing docstring in public function Raw Output: ./hugging_face/url_relevance/huggingface_relevance.py:23:1: D103 Missing docstring in public function
return labels.index(label)


dataset = load_dataset(DATASET)
dataset = concatenate_datasets([dataset["train"], dataset["test"]])
dataset = dataset.shuffle()
dataset = dataset.train_test_split(test_size=0.15)
train_df = pd.DataFrame(dataset["train"])
test_df = pd.DataFrame(dataset["test"])

labels = ["Relevant", "Irrelevant"]
num_labels = len(labels)
label_col = "label"
train_df["label"] = train_df["label"].apply(str2int)
test_df["label"] = test_df["label"].apply(str2int)

text_cols = [
"url_path",
"html_title",
"keywords",
"meta_description",
"root_page_title",
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
] # "url", "http_response"
empty_text_values = ['[""]', None, "[]", '""']
tokenizer = AutoTokenizer.from_pretrained(MODEL)

train_dataset = load_data(
train_df,
text_cols,
tokenizer,
label_col,
label_list=labels,
sep_text_token_str=tokenizer.sep_token,
empty_text_values=empty_text_values,
)
test_dataset = load_data(
test_df,
text_cols,
tokenizer,
label_col,
label_list=labels,
sep_text_token_str=tokenizer.sep_token,
empty_text_values=empty_text_values,
)

config = AutoConfig.from_pretrained(MODEL)
tabular_config = TabularConfig(
num_labels=num_labels,
combine_feat_method="text_only",
)
config.tabular_config = tabular_config
config.max_position_embeddings = 2048

model = AutoModelWithTabular.from_pretrained(MODEL, config=config, ignore_mismatched_sizes=True)

Check failure on line 83 in hugging_face/url_relevance/huggingface_relevance.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 line too long (96 > 79 characters) Raw Output: ./hugging_face/url_relevance/huggingface_relevance.py:83:80: E501 line too long (96 > 79 characters)

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):

Check warning on line 88 in hugging_face/url_relevance/huggingface_relevance.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 Missing docstring in public function Raw Output: ./hugging_face/url_relevance/huggingface_relevance.py:88:1: D103 Missing docstring in public function
logits, labels = eval_pred
# logits_shape = logits[0].shape if isinstance(logits, tuple) else logits.shape
predictions = np.argmax(logits[0], axis=-1) if isinstance(logits, tuple) else np.argmax(logits, axis=-1)
labels = labels.flatten()
predictions = predictions.flatten()

return metric.compute(predictions=predictions, references=labels)


training_args = TrainingArguments(
output_dir="./url_relevance",
logging_dir="./url_relevance/runs",
overwrite_output_dir=True,
do_train=True,
max_steps=MAX_STEPS,
evaluation_strategy="steps",
eval_steps=25,
logging_steps=25,
weight_decay=0.1,
)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
compute_metrics=compute_metrics,
)

trainer.train()
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ requests~=2.31.0
polars~=0.20.10
python-dotenv~=1.0.1
bs4~=0.0.2
tqdm~=4.66.2
pytest~=8.0.1
tqdm>=4.64.1
pytest>=7.2.2
pytest-mock==3.12.0
urllib3~=1.26.18
psycopg2-binary~=2.9.6
Expand All @@ -18,6 +18,7 @@ transformers>=4.38.0
datasets>=2.17.1
accelerate>=0.27.2
numpy>=1.26.4
multimodal-transformers>=0.3.1
# html_tag_collector_only
requests_html>=0.10.0
lxml>=5.1.0
Expand Down
Loading