Skip to content

Commit

Permalink
imprrove vision test
Browse files Browse the repository at this point in the history
  • Loading branch information
MateoLostanlen committed Jul 15, 2024
1 parent d7cfe4e commit 3ced366
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 15 deletions.
24 changes: 12 additions & 12 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,12 @@
MODEL_URL_FOLDER = "https://huggingface.co/pyronear/yolov8s/resolve/main/"
MODEL_ID = "pyronear/yolov8s"
MODEL_NAME = "yolov8s.pt"
METADATA_PATH = "data/model_metadata.json"
METADATA_NAME = "model_metadata.json"


logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True)


def is_arm_architecture():
# Check for ARM architecture
return platform.machine().startswith("arm") or platform.machine().startswith("aarch")


# Utility function to save metadata
def save_metadata(metadata_path, metadata):
with open(metadata_path, "w") as f:
Expand All @@ -54,7 +49,7 @@ class Classifier:
def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format="ncnn", model_path=None) -> None:

Check warning on line 49 in pyroengine/vision.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

pyroengine/vision.py#L49

Redefining built-in 'format'
if model_path is None:
if format == "ncnn":
if is_arm_architecture():
if self.is_arm_architecture():
model = "yolov8s_ncnn_model.zip"
else:
logging.info("NCNN format is optimized for arm architecture only, switching to onnx")
Expand All @@ -63,6 +58,7 @@ def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format=
model = f"yolov8s.{format}"

model_path = os.path.join(model_folder, model)
metadata_path = os.path.join(model_folder, METADATA_NAME)
model_url = MODEL_URL_FOLDER + model

# Get the expected SHA256 from Hugging Face
Expand All @@ -76,15 +72,15 @@ def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format=
# Check if the model file exists
if os.path.isfile(model_path):
# Load existing metadata
metadata = self.load_metadata(METADATA_PATH)
metadata = self.load_metadata(metadata_path)
if metadata and metadata.get("sha256") == expected_sha256:
logging.info("Model already exists and the SHA256 hash matches. No download needed.")
else:
logging.info("Model exists but the SHA256 hash does not match or the file doesn't exist.")
os.remove(model_path)
self.download_model(model_url, model_path, expected_sha256)
self.download_model(model_url, model_path, expected_sha256, metadata_path)
else:
self.download_model(model_url, model_path, expected_sha256)
self.download_model(model_url, model_path, expected_sha256, metadata_path)

file_name, ext = os.path.splitext(model_path)
if ext == ".zip":
Expand All @@ -97,14 +93,18 @@ def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format=
self.conf = conf
self.iou = iou

def is_arm_architecture(self):
# Check for ARM architecture
return platform.machine().startswith("arm") or platform.machine().startswith("aarch")

def get_sha(self, siblings):
# Extract the SHA256 hash from the model files metadata
for file in siblings:
if file.rfilename == os.path.basename(MODEL_NAME):
return file.lfs["sha256"]
return None

Check warning on line 105 in pyroengine/vision.py

View check run for this annotation

Codecov / codecov/patch

pyroengine/vision.py#L105

Added line #L105 was not covered by tests

def download_model(self, model_url, model_path, expected_sha256):
def download_model(self, model_url, model_path, expected_sha256, metadata_path):
# Ensure the directory exists
os.makedirs(os.path.split(model_path)[0], exist_ok=True)

Expand All @@ -116,7 +116,7 @@ def download_model(self, model_url, model_path, expected_sha256):

# Save the metadata
metadata = {"sha256": expected_sha256}
save_metadata(METADATA_PATH, metadata)
save_metadata(metadata_path, metadata)
logging.info("Metadata saved!")

# Utility function to load metadata
Expand Down
67 changes: 64 additions & 3 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
import datetime
import os
from unittest.mock import patch

import numpy as np

from pyroengine.vision import Classifier

METADATA_PATH = "data/model_metadata.json"
model_path = "data/yolov8s.onnx"
sha = "9f1b1c2654d98bbed91e514ce20ea73a0a5fbd1111880f230d516ed40ea2dc58"

def get_creation_date(file_path):
if os.path.exists(file_path):

# For Unix-like systems
stat = os.stat(file_path)
try:
creation_time = stat.st_birthtime
except AttributeError:
# On Unix, use the last modification time as a fallback
creation_time = stat.st_mtime

creation_date = datetime.datetime.fromtimestamp(creation_time)
return creation_date
else:
return None


def test_classifier(tmpdir_factory, mock_wildfire_image):
Expand All @@ -19,7 +36,10 @@ def test_classifier(tmpdir_factory, mock_wildfire_image):
conf = np.max(out[:, 4])
assert 0 <= conf <= 1

# Test onnx model
model = Classifier(model_folder=folder, format="onnx")
model_path = os.path.join(folder, "yolov8s.onnx")
assert os.path.isfile(model_path)

# Test mask
mask = np.ones((384, 640))
Expand All @@ -29,3 +49,44 @@ def test_classifier(tmpdir_factory, mock_wildfire_image):
mask = np.zeros((384, 640))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)

# Test dl pt model
_ = Classifier(model_folder=folder, format="pt")
model_path = os.path.join(folder, "yolov8s.pt")
assert os.path.isfile(model_path)

# Test dl ncnn model
with patch.object(Classifier, "is_arm_architecture", return_value=True):
_ = Classifier(model_folder=folder)
model_path = os.path.join(folder, "yolov8s_ncnn_model")
assert os.path.isdir(model_path)


def test_download(tmpdir_factory):
print("test_classifier")
folder = str(tmpdir_factory.mktemp("engine_cache"))

# Instantiate the ONNX model
model = Classifier(model_folder=folder)

model_path = os.path.join(folder, "yolov8s.onnx")
model_creation_date = get_creation_date(model_path)

# No download if exist
_ = Classifier(model_folder=folder)
model_creation_date2 = get_creation_date(model_path)
assert model_creation_date == model_creation_date2

# Download if does not exist
os.remove(model_path)
_ = Classifier(model_folder=folder)
model_creation_date3 = get_creation_date(model_path)
print(model_creation_date, model_creation_date3)
assert model_creation_date != model_creation_date3

# Download if sha is not the same
with patch.object(Classifier, "get_sha", return_value="sha12"):
_ = Classifier(model_folder=folder)
model_creation_date4 = get_creation_date(model_path)
print(model_creation_date, model_creation_date3)
assert model_creation_date4 != model_creation_date3

0 comments on commit 3ced366

Please sign in to comment.