Skip to content

Commit

Permalink
Merge pull request #27 from guardrails-ai/fix/nltk-dataset
Browse files Browse the repository at this point in the history
Fix NLTK dataset
  • Loading branch information
AlejandroEsquivel authored Nov 5, 2024
2 parents e1ff384 + 543380d commit 6a3660e
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion validator/post-install.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,55 @@
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine

def load_nltk_data():
import re
import nltk
from importlib.metadata import version

nltk_version = version("nltk")
nltk_breaking_version = "3.8.2" # The version where the dataset changed

def parse_major_minor_patch(version: str):
"""Extract the major, minor, and patch version numbers from a version string."""
match = re.match(r"^(0|[1-9]\d*)\.(0|[1-9]\d*)(?:\.(0|[1-9]\d*))?(?:[-+][0-9A-Za-z-.]+)?$", version)
if match:
major = int(match.group(1))
minor = int(match.group(2))
patch = int(match.group(3)) if match.group(3) else 0 # Default to 0 if patch is not provided
return major, minor, patch
else:
raise ValueError(f"Invalid semantic version: '{version}'")

def install_pre_382_dataset():
try:
nltk.data.find("tokenizers/punkt")
except LookupError:
nltk.download("punkt")

def install_post_382_dataset():
try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError:
nltk.download("punkt_tab")

try:
target_major, target_minor, target_patch = parse_major_minor_patch(nltk_breaking_version)
major, minor, patch = parse_major_minor_patch(nltk_version)

if (major, minor, patch) >= (target_major, target_minor, target_patch):
install_post_382_dataset()
elif (major, minor, patch) < (target_major, target_minor, target_patch):
install_pre_382_dataset()
except Exception:
print((
"Error auto-installing nltk dataset, please install manually.\n"
"This can be done with:\n",
"Version < 3.8.2:\n import nltk\n nltk.download('punkt')",
"Version >= 3.8.2:\n import nltk\n nltk.download('punkt_tab')"
))

load_nltk_data()

# Download models
AnalyzerEngine()
AnonymizerEngine()
AnonymizerEngine()

0 comments on commit 6a3660e

Please sign in to comment.