From d284a630bccd2551ef1d6e72a6e562acdccdfd5b Mon Sep 17 00:00:00 2001 From: Alextopher Date: Mon, 1 Jul 2024 22:34:24 -0400 Subject: [PATCH 01/13] rewrite NeuralNetworkModelManager to fascilate tracking more models --- .../NeuralNetworkModelManager.java | 251 +++++++++++++++--- .../org/photonvision/jni/RknnDetectorJNI.java | 5 + .../vision/pipe/impl/RknnDetectionPipe.java | 8 +- .../src/main/java/org/photonvision/Main.java | 5 +- ...v5.txt => note-640-640-yolov5s-labels.txt} | 0 5 files changed, 217 insertions(+), 52 deletions(-) rename photon-server/src/main/resources/models/{labels_v5.txt => note-640-640-yolov5s-labels.txt} (100%) diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java index 75a1ded750..ba6c2267bd 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java @@ -18,24 +18,49 @@ package org.photonvision.common.configuration; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; +import java.net.URL; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; import java.util.List; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; import org.photonvision.rknn.RknnJNI; +/** + * Manages the loading of neural network models. + * + *

Models are loaded from the filesystem at the modelsFolder location. PhotonVision + * also supports shipping pre-trained models as resources in the JAR. If the model has already been + * extracted to the filesystem, it will not be extracted again. + * + *

Each model must have a corresponding labels file. The labels file format is + * simply a list of string names per label, one label per line. The labels file must have the same + * name as the model file, but with the suffix -labels.txt instead of .rknn + * . + * + *

Note: PhotonVision currently only supports YOLOv5 and YOLOv8 models in the .rknn + * format. + */ public class NeuralNetworkModelManager { + /** Singleton instance of the NeuralNetworkModelManager */ private static NeuralNetworkModelManager INSTANCE; - private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config); - private final String MODEL_NAME = "note-640-640-yolov5s.rknn"; - private final RknnJNI.ModelVersion modelVersion = RknnJNI.ModelVersion.YOLO_V5; - private File defaultModelFile; - private List labels; + /** + * Private constructor to prevent instantiation + * + * @return The NeuralNetworkModelManager instance + */ + private NeuralNetworkModelManager() {} + /** + * Returns the singleton instance of the NeuralNetworkModelManager + * + * @return The singleton instance + */ public static NeuralNetworkModelManager getInstance() { if (INSTANCE == null) { INSTANCE = new NeuralNetworkModelManager(); @@ -43,62 +68,200 @@ public static NeuralNetworkModelManager getInstance() { return INSTANCE; } + /** Logger for the NeuralNetworkModelManager */ + private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config); + + /** + * Determines the model version based on the model's filename. + * + *

"yolov5" -> "YOLO_V5" + * + *

"yolov8" -> "YOLO_V8" + * + * @param modelName The model's filename + * @return The model version + */ + private static RknnJNI.ModelVersion getModelVersion(String modelName) + throws IllegalArgumentException { + if (modelName.contains("yolov5")) { + return RknnJNI.ModelVersion.YOLO_V5; + } else if (modelName.contains("yolov8")) { + return RknnJNI.ModelVersion.YOLO_V8; + } else { + throw new IllegalArgumentException("Unknown model version for model " + modelName); + } + } + + /** This class represents a model that can be loaded by the RknnJNI. */ + public class Model { + public final File modelFile; + public final RknnJNI.ModelVersion version; + public final List labels; + + public Model(String model, String labels) throws IllegalArgumentException { + this.version = getModelVersion(model); + this.modelFile = new File(model); + try { + this.labels = Files.readAllLines(Paths.get(labels)); + } catch (IOException e) { + throw new IllegalArgumentException("Error reading labels file " + labels, e); + } + } + + public String getPath() { + return modelFile.getAbsolutePath(); + } + } + + /** + * Stores model information, such as the model file, labels, and version. + * + *

The first model in the list is the default model. + */ + private List models; + /** - * Perform initial setup and extract default model from JAR to the filesystem + * Returns the default rknn model. This is simply the first model in the list. * - * @param modelsFolder Where models live + * @return The default model */ - public void initialize(File modelsFolder) { - var modelResourcePath = "/models/" + MODEL_NAME; - this.defaultModelFile = new File(modelsFolder, MODEL_NAME); - extractResource(modelResourcePath, defaultModelFile); + public Model getDefaultRknnModel() { + return models.get(0); + } - File labelsFile = new File(modelsFolder, "labels_v5.txt"); - var labelResourcePath = "/models/" + labelsFile.getName(); - extractResource(labelResourcePath, labelsFile); + /** + * Enumerates the names of all models. + * + * @return A list of model names + */ + public List getModels() { + return models.stream().map(model -> model.modelFile.getName()).toList(); + } + + /** + * Returns the model with the given name. + * + *

TODO: Java 17 This should return an Optional instead of null. + * + * @param modelName The model name + * @return The model + */ + public Model getModel(String modelName) { + Model m = + models.stream() + .filter(model -> model.modelFile.getName().equals(modelName)) + .findFirst() + .orElse(null); + + if (m == null) { + logger.error("Model " + modelName + " not found."); + } + + return m; + } + + /** + * Loads models from the specified folder. + * + * @param modelsFolder The folder where the models are stored + */ + public void loadModels(File modelsFolder) { + if (!modelsFolder.exists()) { + logger.error("Models folder " + modelsFolder.getAbsolutePath() + " does not exist."); + return; + } + + if (models == null) { + models = new ArrayList<>(); + } try { - labels = Files.readAllLines(Paths.get(labelsFile.getPath())); + Files.walk(modelsFolder.toPath()) + .filter(Files::isRegularFile) + .filter(path -> path.toString().endsWith(".rknn")) + .forEach( + modelPath -> { + String model = modelPath.toString(); + String labels = model.replace(".rknn", "-labels.txt"); + + try { + models.add(new Model(model, labels)); + } catch (IllegalArgumentException e) { + logger.error("Failed to load model " + model, e); + } + }); } catch (IOException e) { - logger.error("Error reading labels.txt", e); + logger.error("Failed to load models from " + modelsFolder.getAbsolutePath(), e); + } + + // Log the loaded models + StringBuilder sb = new StringBuilder(); + sb.append("Loaded models: "); + for (Model model : models) { + sb.append(model.modelFile.getName()).append(", "); } + sb.setLength(sb.length() - 2); + logger.info(sb.toString()); } - private void extractResource(String resourcePath, File outputFile) { - try (var in = NeuralNetworkModelManager.class.getResourceAsStream(resourcePath)) { - if (in == null) { + /** + * Extracts models from a JAR resource and copies them to the specified folder. + * + * @param modelsFolder the folder where the models will be copied to + */ + public void extractModels(File modelsFolder) { + if (!modelsFolder.exists()) { + modelsFolder.mkdirs(); + } + + String resourcePath = "models"; // Adjust path if necessary + try { + URL resourceURL = NeuralNetworkModelManager.class.getClassLoader().getResource(resourcePath); + if (resourceURL == null) { logger.error("Failed to find jar resource at " + resourcePath); return; } - if (!outputFile.exists()) { - try (FileOutputStream fos = new FileOutputStream(outputFile)) { - int read = -1; - byte[] buffer = new byte[1024]; - while ((read = in.read(buffer)) != -1) { - fos.write(buffer, 0, read); + Path resourcePathResolved = Paths.get(resourceURL.toURI()); + Files.walk(resourcePathResolved) + .forEach(sourcePath -> copyResource(sourcePath, resourcePathResolved, modelsFolder)); + } catch (Exception e) { + logger.error("Failed to extract models from JAR", e); + } + } + + /** + * Copies a resource from the source path to the target path. + * + * @param sourcePath The path of the resource to be copied. + * @param resourcePathResolved The resolved path of the resource. + * @param modelsFolder The folder where the resource will be copied to. + */ + private void copyResource(Path sourcePath, Path resourcePathResolved, File modelsFolder) { + Path targetPath = + Paths.get( + modelsFolder.getAbsolutePath(), resourcePathResolved.relativize(sourcePath).toString()); + try { + if (Files.isDirectory(sourcePath)) { + Files.createDirectories(targetPath); + } else { + Path parentDir = targetPath.getParent(); + if (parentDir != null && !Files.exists(parentDir)) { + Files.createDirectories(parentDir); + } + + if (!Files.exists(targetPath)) { + Files.copy(sourcePath, targetPath); + } else { + long sourceSize = Files.size(sourcePath); + long targetSize = Files.size(targetPath); + if (sourceSize != targetSize) { + Files.copy(sourcePath, targetPath, StandardCopyOption.REPLACE_EXISTING); } - } catch (IOException e) { - logger.error("Error extracting resource to " + outputFile.toPath().toString(), e); } - } else { - logger.info( - "File " + outputFile.toPath().toString() + " already exists. Skipping extraction."); } } catch (IOException e) { - logger.error("Error finding jar resource " + resourcePath, e); + logger.error("Failed to copy " + sourcePath + " to " + targetPath, e); } } - - public File getDefaultRknnModel() { - return defaultModelFile; - } - - public List getLabels() { - return labels; - } - - public RknnJNI.ModelVersion getModelVersion() { - return modelVersion; - } } diff --git a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java index c282ec7a79..a247818052 100644 --- a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java +++ b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java @@ -23,6 +23,7 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.stream.Collectors; import org.opencv.core.Mat; +import org.photonvision.common.configuration.NeuralNetworkModelManager; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; import org.photonvision.common.util.TestUtils; @@ -70,6 +71,10 @@ public static class RknnObjectDetector { static volatile boolean hook = false; + public RknnObjectDetector(NeuralNetworkModelManager.Model model) { + this(model.getPath(), model.labels, model.version); + } + public RknnObjectDetector(String modelPath, List labels, RknnJNI.ModelVersion version) { synchronized (lock) { objPointer = RknnJNI.create(modelPath, labels.size(), version.ordinal(), -1); diff --git a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java index 9ce2f348f4..70d02e649a 100644 --- a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java +++ b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java @@ -40,14 +40,10 @@ public class RknnDetectionPipe public RknnDetectionPipe() { // For now this is hard-coded to defaults. Should be refactored into set pipe - // params, though. - // And ideally a little wrapper helper for only changing native stuff on content + // params, though. And ideally a little wrapper helper for only changing native stuff on content // change created. this.detector = - new RknnObjectDetector( - NeuralNetworkModelManager.getInstance().getDefaultRknnModel().getAbsolutePath(), - NeuralNetworkModelManager.getInstance().getLabels(), - NeuralNetworkModelManager.getInstance().getModelVersion()); + new RknnObjectDetector(NeuralNetworkModelManager.getInstance().getDefaultRknnModel()); } private static class Letterbox { diff --git a/photon-server/src/main/java/org/photonvision/Main.java b/photon-server/src/main/java/org/photonvision/Main.java index 9dd5e7ad70..ab3b15aadc 100644 --- a/photon-server/src/main/java/org/photonvision/Main.java +++ b/photon-server/src/main/java/org/photonvision/Main.java @@ -435,8 +435,9 @@ public static void main(String[] args) { .setConfig(ConfigManager.getInstance().getConfig().getNetworkConfig()); logger.info("Loading ML models"); - NeuralNetworkModelManager.getInstance() - .initialize(ConfigManager.getInstance().getModelsDirectory()); + var modelManager = NeuralNetworkModelManager.getInstance(); + modelManager.extractModels(ConfigManager.getInstance().getModelsDirectory()); + modelManager.loadModels(ConfigManager.getInstance().getModelsDirectory()); if (isSmoketest) { logger.info("PhotonVision base functionality loaded -- smoketest complete"); diff --git a/photon-server/src/main/resources/models/labels_v5.txt b/photon-server/src/main/resources/models/note-640-640-yolov5s-labels.txt similarity index 100% rename from photon-server/src/main/resources/models/labels_v5.txt rename to photon-server/src/main/resources/models/note-640-640-yolov5s-labels.txt From 2228c34a370554d03772826a4e66952ba81cc707 Mon Sep 17 00:00:00 2001 From: Alextopher Date: Tue, 2 Jul 2024 09:00:19 -0400 Subject: [PATCH 02/13] rewrite rknn object detector and update rknn detection pipe to respond to model changes --- .../NeuralNetworkModelManager.java | 30 ++++- .../org/photonvision/jni/RknnDetectorJNI.java | 101 ---------------- .../photonvision/jni/RknnObjectDetector.java | 114 ++++++++++++++++++ .../vision/pipe/impl/RknnDetectionPipe.java | 65 +++++++--- .../pipeline/ObjectDetectionPipeline.java | 3 +- .../ObjectDetectionPipelineSettings.java | 1 + 6 files changed, 191 insertions(+), 123 deletions(-) create mode 100644 photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java index ba6c2267bd..6877c49d87 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java @@ -26,6 +26,7 @@ import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.List; +import org.opencv.core.Size; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; import org.photonvision.rknn.RknnJNI; @@ -97,15 +98,36 @@ public class Model { public final File modelFile; public final RknnJNI.ModelVersion version; public final List labels; + public final Size inputSize; + /** + * Model constructor. + * + * @param model format `name-width-height-model.format` + * @param labels + * @throws IllegalArgumentException + */ public Model(String model, String labels) throws IllegalArgumentException { - this.version = getModelVersion(model); + String[] parts = model.split("-"); + if (parts.length != 4) { + throw new IllegalArgumentException("Invalid model file name: " + model); + } + + // TODO: model 'version' need to be replaced the by the product of 'Version' x 'Format' + this.version = getModelVersion(parts[3]); + + int width = Integer.parseInt(parts[1]); + int height = Integer.parseInt(parts[2]); + this.inputSize = new Size(width, height); + this.modelFile = new File(model); try { this.labels = Files.readAllLines(Paths.get(labels)); } catch (IOException e) { - throw new IllegalArgumentException("Error reading labels file " + labels, e); + throw new IllegalArgumentException("Failed to read labels file " + labels, e); } + + logger.info("Loaded model " + modelFile.getName()); } public String getPath() { @@ -141,7 +163,7 @@ public List getModels() { /** * Returns the model with the given name. * - *

TODO: Java 17 This should return an Optional instead of null. + *

TODO: Java 17 This should return an Optional Model instead of null. * * @param modelName The model name * @return The model @@ -214,7 +236,7 @@ public void extractModels(File modelsFolder) { modelsFolder.mkdirs(); } - String resourcePath = "models"; // Adjust path if necessary + String resourcePath = "models"; try { URL resourceURL = NeuralNetworkModelManager.class.getClassLoader().getResource(resourcePath); if (resourceURL == null) { diff --git a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java index a247818052..6ad21fd9cd 100644 --- a/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java +++ b/photon-core/src/main/java/org/photonvision/jni/RknnDetectorJNI.java @@ -18,21 +18,10 @@ package org.photonvision.jni; import java.io.IOException; -import java.util.Arrays; import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.stream.Collectors; -import org.opencv.core.Mat; -import org.photonvision.common.configuration.NeuralNetworkModelManager; -import org.photonvision.common.logging.LogGroup; -import org.photonvision.common.logging.Logger; import org.photonvision.common.util.TestUtils; -import org.photonvision.rknn.RknnJNI; -import org.photonvision.rknn.RknnJNI.RknnResult; -import org.photonvision.vision.pipe.impl.NeuralNetworkPipeResult; public class RknnDetectorJNI extends PhotonJNICommon { - private static final Logger logger = new Logger(RknnDetectorJNI.class, LogGroup.General); private boolean isLoaded; private static RknnDetectorJNI instance = null; @@ -61,94 +50,4 @@ public boolean isLoaded() { public void setLoaded(boolean state) { isLoaded = state; } - - public static class RknnObjectDetector { - long objPointer = -1; - private List labels; - private final Object lock = new Object(); - private static final CopyOnWriteArrayList detectors = - new CopyOnWriteArrayList<>(); - - static volatile boolean hook = false; - - public RknnObjectDetector(NeuralNetworkModelManager.Model model) { - this(model.getPath(), model.labels, model.version); - } - - public RknnObjectDetector(String modelPath, List labels, RknnJNI.ModelVersion version) { - synchronized (lock) { - objPointer = RknnJNI.create(modelPath, labels.size(), version.ordinal(), -1); - detectors.add(this); - logger.debug( - "Created detector " - + objPointer - + " from path " - + modelPath - + "! Detectors: " - + Arrays.toString(detectors.toArray())); - } - this.labels = labels; - - // the kernel should probably alredy deal with this for us, but I'm gunna be paranoid anyways. - if (!hook) { - Runtime.getRuntime() - .addShutdownHook( - new Thread( - () -> { - System.err.println("Shutdown hook rknn"); - for (var d : detectors) { - d.release(); - } - })); - hook = true; - } - } - - public List getClasses() { - return labels; - } - - /** - * Detect forwards using this model - * - * @param in The image to process - * @param nmsThresh Non-maximum supression threshold. Probably should not change - * @param boxThresh Minimum confidence for a box to be added. Basically just confidence - * threshold - */ - public List detect(Mat in, double nmsThresh, double boxThresh) { - RknnResult[] ret; - synchronized (lock) { - // We can technically be asked to detect and the lock might be acquired _after_ release has - // been called. This would mean objPointer would be invalid which would call everything to - // explode. - if (objPointer > 0) { - ret = RknnJNI.detect(objPointer, in.getNativeObjAddr(), nmsThresh, boxThresh); - } else { - logger.warn("Detect called after destroy -- giving up"); - return List.of(); - } - } - if (ret == null) { - return List.of(); - } - return List.of(ret).stream() - .map(it -> new NeuralNetworkPipeResult(it.rect, it.class_id, it.conf)) - .collect(Collectors.toList()); - } - - public void release() { - synchronized (lock) { - if (objPointer > 0) { - RknnJNI.destroy(objPointer); - detectors.remove(this); - System.out.println( - "Killed " + objPointer + "! Detectors: " + Arrays.toString(detectors.toArray())); - objPointer = -1; - } else { - logger.error("RKNN Detector has already been destroyed!"); - } - } - } - } } diff --git a/photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java b/photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java new file mode 100644 index 0000000000..dd065a96dc --- /dev/null +++ b/photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java @@ -0,0 +1,114 @@ +package org.photonvision.jni; + +import java.lang.ref.Cleaner; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import org.opencv.core.Mat; +import org.photonvision.common.configuration.NeuralNetworkModelManager; +import org.photonvision.common.logging.LogGroup; +import org.photonvision.common.logging.Logger; +import org.photonvision.rknn.RknnJNI; +import org.photonvision.vision.opencv.Releasable; +import org.photonvision.vision.pipe.impl.NeuralNetworkPipeResult; + +/** + * A class to represent an object detector using the Rknn library. + * + *

TODO: When we start supporting more platforms, we should consider moving most of this code + * into a common "ObjectDetector" class to define the common interface for all object detectors. + */ +public class RknnObjectDetector implements Releasable { + /** logger for the RknnObjectDetector */ + private static final Logger logger = new Logger(RknnDetectorJNI.class, LogGroup.General); + + /** Cleaner instance to release the detector when it is no longer needed */ + private final Cleaner cleaner = Cleaner.create(); + + /** Pointer to the native object */ + private final long objPointer; + + /** Model configuration */ + private final NeuralNetworkModelManager.Model model; + + /** Returns the model used by the detector. */ + public NeuralNetworkModelManager.Model getModel() { + return model; + } + + /** Atomic boolean to ensure that the detector is only released _once_. */ + private AtomicBoolean released = new AtomicBoolean(false); + + /** + * Creates a new RknnObjectDetector from the given model. + * + * @param model The model to create the detector from. + */ + public RknnObjectDetector(NeuralNetworkModelManager.Model model) { + this.model = model; + + // Create the detector + objPointer = RknnJNI.create(model.getPath(), model.labels.size(), model.version.ordinal(), -1); + if (objPointer <= 0) { + throw new RuntimeException("Failed to create detector from path " + model.getPath()); + } + + logger.debug("Created detector for model " + model.modelFile.getName()); + + // Register the cleaner to release the detector when it goes out of scope + cleaner.register(this, this::release); + + // Set the detector to be released when the JVM exits + Runtime.getRuntime().addShutdownHook(new Thread(this::release)); + } + + /** + * Returns the classes that the detector can detect + * + * @return The classes + */ + public List getClasses() { + return model.labels; + } + + /** + * Detects objects in the given input image using the RknnDetector. + * + * @param in The input image to perform object detection on. + * @param nmsThresh The threshold value for non-maximum suppression. + * @param boxThresh The threshold value for bounding box detection. + * @return A list of NeuralNetworkPipeResult objects representing the detected objects. Returns an + * empty list if the detector is not initialized or if no objects are detected. + */ + public List detect(Mat in, double nmsThresh, double boxThresh) { + if (objPointer <= 0) { + // Report error and make sure to include the model name + logger.error("Detector is not initialized! Model: " + model.modelFile.getName()); + return List.of(); + } + + var results = RknnJNI.detect(objPointer, in.getNativeObjAddr(), nmsThresh, boxThresh); + if (results == null) { + return List.of(); + } + + return List.of(results).stream() + .map(it -> new NeuralNetworkPipeResult(it.rect, it.class_id, it.conf)) + .toList(); + } + + /** Thread-safe method to release the detector. */ + @Override + public void release() { + if (released.compareAndSet(false, true)) { + if (objPointer <= 0) { + logger.error( + "Detector is not initialized, and so it can't be released! Model: " + + model.modelFile.getName()); + return; + } + + RknnJNI.destroy(objPointer); + logger.debug("Released detector for model " + model.modelFile.getName()); + } + } +} diff --git a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java index 70d02e649a..a74b45e464 100644 --- a/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java +++ b/photon-core/src/main/java/org/photonvision/vision/pipe/impl/RknnDetectionPipe.java @@ -27,23 +27,29 @@ import org.opencv.core.Size; import org.opencv.imgproc.Imgproc; import org.photonvision.common.configuration.NeuralNetworkModelManager; +import org.photonvision.common.configuration.NeuralNetworkModelManager.Model; import org.photonvision.common.util.ColorHelper; -import org.photonvision.jni.RknnDetectorJNI.RknnObjectDetector; +import org.photonvision.jni.RknnObjectDetector; import org.photonvision.vision.opencv.CVMat; import org.photonvision.vision.opencv.Releasable; import org.photonvision.vision.pipe.CVPipe; +/** + * A pipe that uses an rknn model to detect objects in an image. + * + *

TODO: This class should be refactored into a generic "ObjectDetectionPipe" that can use any + * "ObjectDetector" implementation. + */ public class RknnDetectionPipe extends CVPipe, RknnDetectionPipe.RknnDetectionPipeParams> implements Releasable { + private RknnObjectDetector detector; public RknnDetectionPipe() { - // For now this is hard-coded to defaults. Should be refactored into set pipe - // params, though. And ideally a little wrapper helper for only changing native stuff on content - // change created. - this.detector = - new RknnObjectDetector(NeuralNetworkModelManager.getInstance().getDefaultRknnModel()); + // Default model + Model model = NeuralNetworkModelManager.getInstance().getDefaultRknnModel(); + this.detector = new RknnObjectDetector(model); } private static class Letterbox { @@ -60,28 +66,40 @@ public Letterbox(double dx, double dy, double scale) { @Override protected List process(CVMat in) { - var frame = in.getMat(); + // Check if the model has changed + if (detector.getModel() != params.model) { + detector.release(); + detector = new RknnObjectDetector(params.model); + } - // Make sure we don't get a weird empty frame + Mat frame = in.getMat(); if (frame.empty()) { return List.of(); } - // letterbox - var letterboxed = new Mat(); - var scale = - letterbox(frame, letterboxed, new Size(640, 640), ColorHelper.colorToScalar(Color.GRAY)); - - if (letterboxed.width() != 640 || letterboxed.height() != 640) { - // huh whack give up lol - throw new RuntimeException("RGA bugged but still wrong size"); + // Resize the frame to the input size of the model + Size shape = this.params.model.inputSize; + Mat letterboxed = new Mat(); + Letterbox scale = letterbox(frame, letterboxed, shape, ColorHelper.colorToScalar(Color.GRAY)); + if (!letterboxed.size().equals(shape)) { + throw new RuntimeException("Letterboxed frame is not the right size!"); } - var ret = detector.detect(letterboxed, params.nms, params.confidence); + + // Detect objects in the letterboxed frame + List ret = detector.detect(letterboxed, params.nms, params.confidence); letterboxed.release(); + // Resize the detections to the original frame size return resizeDetections(ret, scale); } + /** + * Resizes the detections to the original frame size. + * + * @param unscaled The detections to resize + * @param letterbox The letterbox information + * @return The resized detections + */ private List resizeDetections( List unscaled, Letterbox letterbox) { var ret = new ArrayList(); @@ -101,6 +119,18 @@ private List resizeDetections( return ret; } + /** + * Resize the frame to the new shape and "letterbox" it. + * + *

Letterboxing is the process of resizing an image to a new shape while maintaining the aspect + * ratio of the original image. The new image is padded with a color to fill the remaining space. + * + * @param frame + * @param letterboxed + * @param newShape + * @param color + * @return + */ private static Letterbox letterbox(Mat frame, Mat letterboxed, Size newShape, Scalar color) { // from https://github.com/ultralytics/yolov5/issues/8427#issuecomment-1172469631 var frameSize = frame.size(); @@ -134,6 +164,7 @@ public static class RknnDetectionPipeParams { public double confidence; public double nms; public int max_detections; + public NeuralNetworkModelManager.Model model; public RknnDetectionPipeParams() {} } diff --git a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java index 4919c91512..6f35ec519e 100644 --- a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java +++ b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.stream.Collectors; +import org.photonvision.common.configuration.NeuralNetworkModelManager; import org.photonvision.vision.frame.Frame; import org.photonvision.vision.frame.FrameThresholdType; import org.photonvision.vision.opencv.DualOffsetValues; @@ -56,6 +57,7 @@ protected void setPipeParamsImpl() { var params = new RknnDetectionPipeParams(); params.confidence = settings.confidence; params.nms = settings.nms; + params.model = NeuralNetworkModelManager.getInstance().getModel(settings.model); rknnPipe.setParams(params); DualOffsetValues dualOffsetValues = @@ -99,7 +101,6 @@ protected CVPipelineResult process(Frame frame, ObjectDetectionPipelineSettings CVPipeResult> rknnResult = rknnPipe.run(frame.colorImage); sumPipeNanosElapsed += rknnResult.nanosElapsed; - List targetList; var names = rknnPipe.getClassNames(); diff --git a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java index 118e73b329..ed8fb0de25 100644 --- a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java +++ b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java @@ -20,6 +20,7 @@ public class ObjectDetectionPipelineSettings extends AdvancedPipelineSettings { public double confidence; public double nms; // non maximal suppression + public String model; public ObjectDetectionPipelineSettings() { super(); From 7f114ca1103ffecd664a5bbcee3ff4aaf5aad7be Mon Sep 17 00:00:00 2001 From: Alextopher Date: Tue, 2 Jul 2024 12:13:28 -0400 Subject: [PATCH 03/13] Add (untested) ui elements --- .../dashboard/tabs/ObjectDetectionTab.vue | 9 +++++++++ .../src/stores/settings/GeneralSettingsStore.ts | 6 ++++-- photon-client/src/types/PipelineTypes.ts | 4 +++- photon-client/src/types/SettingTypes.ts | 1 + .../configuration/NeuralNetworkModelManager.java | 16 +++++++--------- .../configuration/PhotonConfiguration.java | 1 + .../org/photonvision/jni/RknnObjectDetector.java | 6 ++++-- .../vision/pipeline/ObjectDetectionPipeline.java | 7 ++++++- .../ObjectDetectionPipelineSettings.java | 7 +++++++ 9 files changed, 42 insertions(+), 15 deletions(-) diff --git a/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue b/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue index 25de890ac0..513ef48b37 100644 --- a/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue +++ b/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue @@ -4,6 +4,7 @@ import { type ActivePipelineSettings, PipelineType } from "@/types/PipelineTypes import PvSlider from "@/components/common/pv-slider.vue"; import { computed, getCurrentInstance } from "vue"; import { useStateStore } from "@/stores/StateStore"; +import { useSettingsStore } from "@/stores/settings/GeneralSettingsStore"; // TODO fix pipeline typing in order to fix this, the store settings call should be able to infer that only valid pipeline type settings are exposed based on pre-checks for the entire config section // Defer reference to store access method @@ -31,6 +32,14 @@ const interactiveCols = computed(() =>