
Problem converting Yolov5n .pt to .tflite #13445

Open
juliermeSilva opened this issue Dec 6, 2024 · 9 comments
Labels: bug (Something isn't working), detect (Object Detection issues, PRs), exports (Model exports: ONNX, TensorRT, TFLite, etc.)


juliermeSilva commented Dec 6, 2024

Hello everyone.
I converted a Yolov5 nano network from .pt format to .tflite format. The network only has two classes (Black Ball and Silver Ball). When using the .tflite file, the network stopped detecting Silver Ball. It only detects Black Ball. This behavior does not occur when I use the .pt format.
The code used for training is this:

!pip install ultralytics

from ultralytics import YOLO
model = YOLO("yolov5n.yaml")

import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api")
wandb.login(key=secret_value_0)

results = model.train(
    data='/kaggle/input/dataset/ball-black-silver-yolov5n/data.yaml', 
    epochs=100, 
    patience=30, 
    device='0,1', 
    model='yolov5n.pt', 
    lr0=0.001, 
    freeze=10,
    hsv_h=0.015,      
    hsv_s=0.7,        
    hsv_v=0.4,        
    degrees=10.0,     
    translate=0.1,    
    scale=0.5,        
    shear=2.0,        
    fliplr=0.5,       
    mosaic=1.0,       
    mixup=0.2,        
    warmup_epochs=3.0,
    cos_lr=True,
    weight_decay=0.0001
)

# Export from .pt to .onnx
model.export(format='onnx')

My file data.yaml:

train: /kaggle/input/dataset/ball-black-silver-yolov5n/train/images
val: /kaggle/input/dataset/ball-black-silver-yolov5n/valid/images
test: /kaggle/input/dataset/ball-black-silver-yolov5n/test/images

nc: 2
names: ['Black Ball', 'Silver Ball']


roboflow:
  workspace: testingai-dfp2w
  project: testing-ai-obr
  version: 2
  license: CC BY 4.0
  url: https://universe.roboflow.com/testingai-dfp2w/testing-ai-obr/dataset/2

The yolov5n.pt and yolov5n.onnx formats worked perfectly on a more robust computer (e.g. a university laptop).
I now need to embed the Yolov5n network on a Raspberry Pi 3B.
I then converted the network from the .pt format to the .tflite format.
The conversion code used was this:

!pip install ultralytics

from ultralytics import YOLO
print("Sucess load python libs!")

!cp /kaggle/input/models-obr/yolov5n-obr.pt /kaggle/working
pt_model_path = "/kaggle/working/yolov5n-obr.pt"
model = YOLO(pt_model_path)
print("Sucess load yolov5n-obr!")

model = YOLO(pt_model_path)
print(model.names)

model.export(format="tflite", imgsz=640, optimize=None)

!yolo val task=detect model=/kaggle/working/yolov5n-obr_saved_model/yolov5n-obr_float32.tflite imgsz=640 data=/kaggle/input/dataset/ball-black-silver-yolov5n/data.yaml

Validation results:

Ultralytics 8.3.43 🚀 Python-3.10.14 torch-2.4.0 CPU (Intel Xeon 2.00GHz)
Loading /kaggle/working/yolov5n-obr_saved_model/yolov5n-obr_float32.tflite for TensorFlow Lite inference...
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Setting batch=1 input of shape (1, 3, 640, 640)
Downloading https://ultralytics.com/assets/Arial.ttf to '/root/.config/Ultralytics/Arial.ttf'...
val: Scanning /kaggle/input/dataset/ball-black-silver-yolov5n/valid/labels... 71
val: WARNING ⚠️ Cache directory /kaggle/input/dataset/ball-black-silver-yolov5n/valid is not writeable, cache not saved.
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)
                   all        710       1891      0.997       0.98      0.993      0.871
            Black Ball        571        690      0.996      0.974      0.992      0.868
           Silver Ball        673       1201      0.998      0.986      0.994      0.875
Speed: 0.9ms preprocess, 129.1ms inference, 0.0ms loss, 0.8ms postprocess per image
Results saved to runs/detect/val
💡 Learn more at https://docs.ultralytics.com/modes/val

The training code and the conversion code were all run on the Kaggle platform.
Everything was going perfectly until I tried to detect silver balls.
This is my test image (img1.jpg):
[image: img1.jpg]

This is the code I am using to perform inferences using the yolov5n-obr_float32_v2.tflite network (tflite_black_silver_test_3.py):

import tflite_runtime.interpreter as tflite
import cv2
import numpy as np

id_img = 1

# Paths to the model and label map
MODEL_PATH = "yolov5n-obr_float32_v2.tflite"
LABELMAP_PATH = "labelmap.txt"

def calculate_iou(box1, box2):
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    # Intersection area    
    intersection = max(0, x2 - x1) * max(0, y2 - y1)     
    # Areas of the bounding boxes
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    # Union area        
    union = box1_area + box2_area - intersection     
    return intersection / union if union > 0 else 0


def nms(boxes, confidences, iou_threshold=0.5):    
    indices = np.argsort(confidences)[::-1]  # Sort by confidence (descending)
    keep = []
    while len(indices) > 0:
        current = indices[0]
        keep.append(current)
        others = indices[1:]        
        # Calculate IoU of the current box with the others
        ious = [calculate_iou(boxes[current], boxes[idx]) for idx in others]
        # Keep only boxes with IoU below the threshold
        indices = [idx for i, idx in enumerate(others) if ious[i] <= iou_threshold]    
    return keep


if __name__ == '__main__':
    # Load the label map
    with open(LABELMAP_PATH, "r") as f:
        labels = [line.strip() for line in f.readlines()]

    # Load the model
    interpreter = tflite.Interpreter(model_path=MODEL_PATH)
    interpreter.allocate_tensors()

    # Get tensor details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    # Load a test image with OpenCV
    image_path = "img" + str(id_img) + ".jpg"
    image = cv2.imread(image_path)
    original_height, original_width, _ = image.shape

    # Resize the image to match the model's input size
    input_size = input_details[0]['shape'][1:3]
    resized_image = cv2.resize(image, (input_size[1], input_size[0]))

    # Check the expected type of the input tensor
    if input_details[0]['dtype'] == np.uint8:
        input_data = np.expand_dims(resized_image, axis=0).astype(np.uint8)
    else:
        input_data = np.expand_dims(resized_image.astype(np.float32) / 255.0, axis=0)  # Normalization

    # Perform inference
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # Get the results
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Reshape for easier manipulation
    output_data = np.squeeze(output_data)  # Remove dimensions of size 1 (e.g., batch size)
    
    # Store predictions for NMS
    all_boxes = []
    all_confidences = []
    all_labels = []

    # Iterate through the predictions
    for i in range(output_data.shape[1]):  # Iterate through 8400 predictions
        x_center, y_center, width, height, confidence, class_id = output_data[:, i]
        # Filter by confidence
        if confidence >= 0.65:  # Adjust confidence threshold as needed
            # Calculate normalized coordinates (x_min, y_min, x_max, y_max)
            x_min = x_center - (width / 2)
            y_min = y_center - (height / 2)
            x_max = x_center + (width / 2)
            y_max = y_center + (height / 2)
            # Convert normalized coordinates to absolute (denormalization)
            y_min_abs = int(y_min * original_height)
            x_min_abs = int(x_min * original_width)
            y_max_abs = int(y_max * original_height)
            x_max_abs = int(x_max * original_width)
            all_boxes.append([x_min_abs, y_min_abs, x_max_abs, y_max_abs])
            all_confidences.append(confidence)
            all_labels.append(int(class_id))

    # Apply Non-Maximum Suppression (NMS)
    nms_indices = nms(all_boxes, all_confidences, iou_threshold=0.5)

    # Display results filtered by NMS
    for idx in nms_indices:
        box = all_boxes[idx]
        confidence = all_confidences[idx]
        label = all_labels[idx]
        cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
        cv2.putText(image, f"{labels[label]}: {confidence:.2f}", (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Save the image with bounding boxes
    output_path = "result.jpg"
    cv2.imwrite(output_path, image)

The resulting image from running the tflite_black_silver_test_3.py code is the one below. Note that the silver balls are not being detected, only the black ball.
[image: result4 — only the black ball is detected]

Kaggle settings for conversion:
ultralytics-8.3.43
Installed dependencies:

Ultralytics 8.3.43 🚀 Python-3.10.14 torch-2.4.0 CPU (Intel Xeon 2.00GHz)
YOLOv5n summary (fused): 211 layers, 2,182,054 parameters, 0 gradients, 5.8 GFLOPs

PyTorch: starting from '/kaggle/working/yolov5n-obr.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 6, 8400) (4.4 MB)
requirements: Ultralytics requirements ['sng4onnx>=1.0.1', 'onnx_graphsurgeon>=0.3.26', 'onnx2tf>1.17.5,<=1.22.3', 'onnxslim>=0.1.31', 'tflite_support', 'onnxruntime'] not found, attempting AutoUpdate...
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com/
Collecting sng4onnx>=1.0.1
  Downloading sng4onnx-1.0.4-py3-none-any.whl.metadata (4.6 kB)
Collecting onnx_graphsurgeon>=0.3.26
  Downloading onnx_graphsurgeon-0.5.2-py2.py3-none-any.whl.metadata (8.1 kB)
Collecting onnx2tf<=1.22.3,>1.17.5
  Downloading onnx2tf-1.22.3-py3-none-any.whl.metadata (136 kB)
Collecting onnxslim>=0.1.31
  Downloading onnxslim-0.1.43-py3-none-any.whl.metadata (4.2 kB)
Collecting tflite_support
  Downloading tflite_support-0.4.4-cp310-cp310-manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from onnx_graphsurgeon>=0.3.26) (1.26.4)
Requirement already satisfied: onnx>=1.14.0 in /opt/conda/lib/python3.10/site-packages (from onnx_graphsurgeon>=0.3.26) (1.17.0)
Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from onnxslim>=0.1.31) (1.13.3)
Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from onnxslim>=0.1.31) (21.3)
Requirement already satisfied: absl-py>=0.7.0 in /opt/conda/lib/python3.10/site-packages (from tflite_support) (1.4.0)
Requirement already satisfied: flatbuffers>=2.0 in /opt/conda/lib/python3.10/site-packages (from tflite_support) (24.3.25)
Requirement already satisfied: protobuf<4,>=3.18.0 in /opt/conda/lib/python3.10/site-packages (from tflite_support) (3.20.3)
Collecting sounddevice>=0.4.4 (from tflite_support)
  Downloading sounddevice-0.5.1-py3-none-any.whl.metadata (1.4 kB)
Requirement already satisfied: pybind11>=2.6.0 in /opt/conda/lib/python3.10/site-packages (from tflite_support) (2.13.6)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Requirement already satisfied: CFFI>=1.0 in /opt/conda/lib/python3.10/site-packages (from sounddevice>=0.4.4->tflite_support) (1.16.0)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->onnxslim>=0.1.31) (3.1.2)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.10/site-packages (from sympy->onnxslim>=0.1.31) (1.3.0)
Requirement already satisfied: pycparser in /opt/conda/lib/python3.10/site-packages (from CFFI>=1.0->sounddevice>=0.4.4->tflite_support) (2.22)
Downloading sng4onnx-1.0.4-py3-none-any.whl (5.9 kB)
Downloading onnx_graphsurgeon-0.5.2-py2.py3-none-any.whl (56 kB)
Downloading onnx2tf-1.22.3-py3-none-any.whl (435 kB)
Downloading onnxslim-0.1.43-py3-none-any.whl (142 kB)
Downloading tflite_support-0.4.4-cp310-cp310-manylinux2014_x86_64.whl (60.8 MB)
Downloading onnxruntime-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (13.3 MB)
Downloading sounddevice-0.5.1-py3-none-any.whl (32 kB)
Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
Installing collected packages: sng4onnx, onnx2tf, humanfriendly, sounddevice, onnxslim, onnx_graphsurgeon, coloredlogs, tflite_support, onnxruntime
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnx2tf-1.22.3 onnx_graphsurgeon-0.5.2 onnxruntime-1.20.1 onnxslim-0.1.43 sng4onnx-1.0.4 sounddevice-0.5.1 tflite_support-0.4.4

requirements: AutoUpdate success ✅ 13.3s, installed 6 packages: ['sng4onnx>=1.0.1', 'onnx_graphsurgeon>=0.3.26', 'onnx2tf>1.17.5,<=1.22.3', 'onnxslim>=0.1.31', 'tflite_support', 'onnxruntime']

Settings on the Raspberry Pi 3B:
Python: 3.11.2
tflite-runtime 2.14.0

What could I be doing wrong?

Why does the network work perfectly in the .pt format but this problem occurs in the .tflite format?

I appreciate any help!

Originally posted by @juliermeSilva in #13444

UltralyticsAssistant added the bug, detect, and exports labels on Dec 6, 2024
UltralyticsAssistant (Member) commented

👋 Hello @juliermeSilva, thank you for your interest in YOLOv5 🚀! This is an automated response to help guide you, and an Ultralytics engineer will assist you further shortly.

If this is a 🐛 Bug Report, it would be helpful to provide a minimum reproducible example (MRE). In this case, please ensure that any custom code, dataset details, or additional information impacting the .pt to .tflite conversion is included. The information you've provided so far is quite detailed, which is great! Ensuring consistency in the training, exporting, and inference pipelines is critical for debugging discrepancies.

For general issues with custom training and inference ❓, please make sure:

  • Your dataset is labeled accurately and corresponds well to your data.yaml.
  • You're performing proper validation of the .tflite model after export.
  • You're adhering to best practices for quantization (if applicable) or export formats when converting models for edge devices like Raspberry Pi.

Regarding your Python environment:

  • Please confirm you are using Python>=3.8.0, and that all dependencies, such as PyTorch, TensorFlow Lite runtime, and OpenCV, are up to date.

If you're utilizing a custom inference pipeline (like tflite_runtime.interpreter in your case), ensure that input preprocessing (e.g., resizing, normalization) matches the expected format of the converted model.
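For example, a minimal preprocessing sketch for a float32 TFLite export could look like the following (this assumes the model expects RGB input in NHWC layout normalized to [0, 1]; the exact requirements depend on your export settings):

import cv2
import numpy as np

def preprocess(image_bgr, input_size=(640, 640)):
    # OpenCV loads images as BGR; most exported models expect RGB input
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(image_rgb, input_size)
    # Float32 exports typically expect values normalized to [0, 1]
    return np.expand_dims(resized.astype(np.float32) / 255.0, axis=0)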

Lastly, if the .pt to .tflite conversion step was performed with specific optimizations, such as quantization, simplifying your model/export settings to ensure compatibility may help narrow down the issue.

Let us know if further clarification is needed, and thank you for sharing your detailed report! 😊

juliermeSilva (Author) commented

Here is some extra data that I forgot to include in the previous post.
Image produced by the Netron application from the yolov5n-obr.onnx file created by the Ultralytics conversion process (code in the previous post):
[Screenshot: Netron view of the yolov5n-obr.onnx file]

Image produced by the Netron application from the yolov5n-obr.tflite file created by the same conversion process:
[Screenshot: Netron view of the yolov5n-obr.tflite file]

Using Netron, it is possible to observe that the .tflite file exposes no information about the classes the model handles, whereas the .onnx file explicitly displays the model classes in its metadata.

I do not know whether this is normal behavior for a .tflite file or whether there really is a problem in the Ultralytics conversion from YOLOv5n to .tflite.
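If it helps, here is a minimal sketch of one way to check whether the exported .tflite file carries embedded metadata, using the tflite_support package that the export step installs (this assumes the MetadataDisplayer API; the file name is the one from my export):

from tflite_support import metadata

displayer = metadata.MetadataDisplayer.with_model_file("yolov5n-obr_float32.tflite")
print(displayer.get_metadata_json())                # metadata JSON, if any is embedded
print(displayer.get_packed_associated_file_list())  # associated files (e.g., label maps)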

juliermeSilva (Author) commented

Hello everyone.
I ran some more tests.
I converted YOLOv5, YOLOv8, and YOLO11 models already trained by the Ultralytics team, testing different model sizes (nano, small, medium, large, xlarge).
I always used this command:

#model = YOLO("yolov5n.pt") # YOLOV5 nano: Model pre-trained on COCO base
#model = YOLO("yolov5s.pt") # YOLOV5 small: Model pre-trained on COCO base
#model = YOLO("yolov5m.pt") # YOLOV5 small: Model pre-trained on COCO base
#model = YOLO("yolov5l.pt") # YOLOV5 small: Model pre-trained on COCO base
#model = YOLO("yolov5x.pt") # YOLOV5 small: Model pre-trained on COCO base

#model = YOLO("yolov8n.pt") # YOLOV8s small: Model pre-trained on COCO base
model = YOLO("yolov8s.pt") # YOLOV8s small: Model pre-trained on COCO base

model.export(format="tflite", imgsz=640, optimize=None)

The result remained the same as reported in previous posts: only the first class was detected by the models; the others were ignored. I don't know why this happens when I use the tflite_runtime library to perform inference.

I then used the Ultralytics library to load the yolov5n-obr.tflite model and test inference, printing the results to the terminal. Note that the model behaves as expected, making correct inferences for all test images.

Code used:

# Inference yolov5n-obr.tflite
from ultralytics import YOLO
model_path = "/kaggle/working/yolov5n-obr_saved_model/yolov5n-obr_float32.tflite"
image_path = "/kaggle/input/test-black-silver-ball/img1.jpg"
#image_path = "/kaggle/input/test-black-silver-ball/img2.jpg"
#image_path = "/kaggle/input/test-black-silver-ball/img3.jpg"
#image_path = "/kaggle/input/test-black-silver-ball/img4.jpg"

# Load the YOLO model in TFLite format
model = YOLO(model_path)

# Perform the inference
results = model(image_path)

# Check if there was any inference
detected_classes = []
for result in results:  
    if (result.boxes is not None):  
        for box in result.boxes:  
            cls_id = int(box.cls[0])  
            conf = box.conf[0]  
            if (conf > 0.5):
                detected_classes.append(result.names[cls_id])

# Display detected ball types
print("Types of balls detected in the image:")
if (detected_classes):
    for detected_class in detected_classes:
        print(f"- {detected_class}")
else:
    print("No ball detected.")

Example result (for img1.jpg):

Loading /kaggle/working/yolov5n-obr_saved_model/yolov5n-obr_float32.tflite for TensorFlow Lite inference...

image 1/1 /kaggle/input/test-black-silver-ball/img1.jpg: 640x640 1 Black Ball, 2 Silver Balls, 211.4ms
Speed: 3.2ms preprocess, 211.4ms inference, 1.3ms postprocess per image at shape (1, 3, 640, 640)
Types of balls detected in the image:
- Black Ball
- Silver Ball
- Silver Ball

I hope this information helps in figuring out why inference with the tflite_runtime library does not correctly detect all the classes the model was trained on.

I am available if any further information is needed to find the solution to this problem.

pderrenger (Member) commented

@juliermeSilva thank you for your detailed testing and additional information. Based on your results, it appears the issue lies with how the tflite_runtime library processes the model's outputs, rather than with the YOLOv5/TFLite export or the model itself. When using the Ultralytics library for inference, the detections are accurate, confirming that the exported TFLite model is functioning as intended.

Here are some potential areas to explore within the tflite_runtime setup:

  1. Model Output Post-Processing: Ensure that tflite_runtime processes the model outputs correctly. Specifically, validate how the confidences and class predictions are being handled. From your inference code, verify the model outputs match the expected TensorFlow Lite format. Differences in axis indexing or output dimensions may lead to certain classes being ignored.

  2. Normalization and Confidence Thresholds: Double-check that input normalizations and confidence thresholds in tflite_runtime are consistent with Ultralytics' implementation. Mismatched preprocessing or thresholds may cause issues interpreting predictions.

  3. Class Mapping: Ensure class IDs are being translated correctly in the tflite_runtime inference results. This might account for certain classes not being interpreted properly.

For further diagnosis, consider logging and comparing the raw output tensors from both tflite_runtime and the Ultralytics library to identify differences. Additionally, if the issue persists, you can simulate TFLite inference using TensorFlow Lite's Interpreter rather than tflite_runtime, which might provide clearer debugging outputs.
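As a minimal sketch (assuming the interpreter, input_details, output_details, and input_data prepared exactly as in your tflite_black_silver_test_3.py script), you could dump the raw tensor for offline comparison:

import numpy as np

# interpreter, input_details, output_details, and input_data are assumed
# to be set up as in your tflite_black_silver_test_3.py script
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
raw = interpreter.get_tensor(output_details[0]['index'])
print("shape:", raw.shape, "dtype:", raw.dtype, "min/max:", raw.min(), raw.max())
np.save("tflite_raw_output.npy", raw)  # compare offline against the Ultralytics raw output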

Let me know if you need help debugging the raw TFLite outputs!

juliermeSilva (Author) commented

Hello pderrenger, I appreciate your attention and your initiative in helping me solve this problem with YOLOv5n converted to .tflite.
I ran new tests, this time converting the YOLOv5n model to the .ncnn format.
All models work normally when tested with the Ultralytics library.
The problem must lie in how the neural model is handled with the tflite-runtime library.
I'm missing something that I can't yet see.

In these new tests, the script that handles the yolov5n-obr.tflite model with the tflite-runtime library printed the output data structure of the neural model.
I have put all the tests and their respective results carried out so far in a notebook available on the Kaggle platform: https://www.kaggle.com/code/juliermesilva/test-ultralytics-convert

Any help is welcome.

pderrenger (Member) commented

Thank you for the detailed update and extensive testing. Based on your findings, it does seem that the issue lies in post-processing the TFLite model outputs using the tflite-runtime library. Since the exported TFLite model works correctly with the Ultralytics library, the model itself and the export process are validated as correct.

To address this, double-check how the outputs are parsed and interpreted in your tflite-runtime implementation. Specifically, ensure that class IDs, confidences, and bounding box coordinates are being correctly extracted and processed. Comparing the raw outputs from both tflite-runtime and Ultralytics inference may help identify discrepancies. If necessary, consult the TensorFlow Lite documentation to confirm you're correctly handling the model's output structure.
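For instance, a quick sketch that prints the model's input/output structure before any parsing code runs (using only the standard tflite_runtime API):

import tflite_runtime.interpreter as tflite

interpreter = tflite.Interpreter(model_path="yolov5n-obr_float32_v2.tflite")
interpreter.allocate_tensors()
for detail in interpreter.get_input_details():
    print("input:", detail['name'], detail['shape'], detail['dtype'])
for detail in interpreter.get_output_details():
    print("output:", detail['name'], detail['shape'], detail['dtype'])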

Let me know if you need further guidance!

juliermeSilva (Author) commented

Hello, pderrenger.
I have been checking the model's output tensor in .tflite format. From the research I have done on the internet, in the documentation, and in the examples I have found, the output tensor format does not look unusual.

I obtain the tensor with the inference data from the information in output_details; for example, the value 476 is the index of the tensor containing the inference data, which has shape [1, 6, 8400]. Once you know the index of the output tensor, getting it is simple; just run this command:

output_data = interpreter.get_tensor(output_details[0]['index'])[0]

The format of the output tensor is exactly as described in output_details:

output_data:
[[ 0.013904 0.022383 0.043986 ... 0.7974 0.91485 0.94513]
[ 0.017683 0.011316 0.0066985 ... 0.87934 0.86157 0.8148]
[ 0.028001 0.045665 0.085364 ... 0.40485 0.17223 0.10845]
[ 0.057635 0.026243 0.014182 ... 0.24054 0.27528 0.36619]
[ 7.8421e-07 1.7045e-07 2.8345e-08 ... 1.237e-06 3.8692e-06 1.2041e-06]
[ 2.8902e-06 7.1576e-07 2.8083e-07 ... 1.793e-06 1.1865e-05 3.9144e-06]]

There are 8400 columns and 6 rows. Each column holds the 6 values inferred by the model, that is:

x_center,
y_center,
width,
height,
confidence,
class

I have already checked the official documentation and did not notice any anomalies.

I do not believe the problem is in the reading of the data structures.

Perhaps it is related to some scale adjustment that I am not able to see (my speculation).

If there is something wrong, I still cannot see it.

I added a new test to the Kaggle notebook, this time reading images with the PIL library, since I thought it could be something related to the input data. I also simplified the test code as much as possible to focus on the format of the data structures.
https://www.kaggle.com/code/juliermesilva/test-ultralytics-convert

To avoid polluting this issue with huge outputs, I think it is better to keep the large results on Kaggle.

juliermeSilva (Author) commented Dec 16, 2024

I discovered the problem!
After printing and plotting the values of the output tensor returned by tflite_runtime, I understood the pattern this library works with.
The details of the output tensor show its shape, [1, 6, 8400], which is confirmed when the tensor data is printed.

Output tensor details:
 [{'name': 'Identity', 'index': 476, 'shape': array([   1,    6, 8400], dtype=int32), 'shape_signature': array([   1,    6, 8400], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

Example output tensor data:

output_data:
[[ 0.013904 0.022383 0.043986 ... 0.7974 0.91485 0.94513]
[ 0.017683 0.011316 0.0066985 ... 0.87934 0.86157 0.8148]
[ 0.028001 0.045665 0.085364 ... 0.40485 0.17223 0.10845]
[ 0.057635 0.026243 0.014182 ... 0.24054 0.27528 0.36619]
[ 7.8421e-07 1.7045e-07 2.8345e-08 ... 1.237e-06 3.8692e-06 1.2041e-06]
[ 2.8902e-06 7.1576e-07 2.8083e-07 ... 1.793e-06 1.1865e-05 3.9144e-06]]

In the shape definition, the value 8400 is the number of columns, which represents the number of predictions made by the model. The value 6 is the number of rows in the output tensor, and the meaning of each row has been the real problem so far. I was interpreting the values as follows (reading from top to bottom):

line 1: x_center
line 2: y_center
line 3: width
line 4: height
line 5: confidence
line 6: class

Therefore, I always read this data with the code below, guided by the way YOLO normally returns its data:

output_data = interpreter.get_tensor(output_details[0]['index'])[0]
for i in range(output_data.shape[1]):  # Iterate through 8400 predictions
    x_center, y_center, width, height, confidence, class_id = output_data[:, i]

But the format of the .tflite output tensor is different. There is no specific line for class values (in my case, 0 for the Black Ball class and 1 for the Silver Ball class). The last two lines of the output tensor hold the confidence values for detections of class 0 and class 1, and the first four lines hold the bounding-box coordinates.
So the format of the output tensor is:

line 1: x_center
line 2: y_center
line 3: width
line 4: height
line 5: confidence_class0
line 6: confidence_class1

Correct reading of the tensor output of the .tflite model, using the tflite_runtime library:

output_data = interpreter.get_tensor(output_details[0]['index'])[0]
for i in range(output_data.shape[1]):  # Iterate through 8400 predictions
    x_center, y_center, width, height, confidence_class0, confidence_class1 = output_data[:, i]

Image of the result after correcting the reading of the output tensor:
[Screenshot: result after correcting the reading of the output tensor]

To confirm my understanding, I converted yolov8n.pt to yolov8n.tflite. This model was trained by Ultralytics on the COCO dataset, which has 80 classes:

{0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}

Note that the output tensor of the yolov8n.tflite model has 84 rows: 'shape': array([1, 84, 8400]).

Output tensor details:
[{'name': 'Identity', 'index': 407, 'shape': array([ 1, 84, 8400], dtype=int32), 'shape_signature': array([ 1, 84, 8400], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

The first 4 lines hold the bounding-box coordinates, and the other 80 lines hold the confidences for the detections of the 80 classes.

line 1: x_center
line 2: y_center
line 3: width
line 4: height
line 5: confidence_class0
line 6: confidence_class1
.
.
.
line 84: confidence_class79
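Based on this layout, here is a minimal sketch of the corrected parsing for an arbitrary number of classes, taking the argmax over the class-confidence rows (it assumes the interpreter and output_details set up as in my earlier script; the 0.65 threshold is the one from my test code):

import numpy as np

output_data = interpreter.get_tensor(output_details[0]['index'])[0]  # shape (4 + nc, 8400)
boxes = output_data[:4, :]        # rows 0-3: x_center, y_center, width, height
class_confs = output_data[4:, :]  # rows 4..: one confidence row per class

class_ids = np.argmax(class_confs, axis=0)  # best class for each of the 8400 predictions
confidences = class_confs[class_ids, np.arange(class_confs.shape[1])]

keep = confidences >= 0.65  # confidence threshold from my test script
for (x_c, y_c, w, h), conf, cls in zip(boxes[:, keep].T, confidences[keep], class_ids[keep]):
    print(f"class={cls} conf={conf:.2f} box=({x_c:.3f}, {y_c:.3f}, {w:.3f}, {h:.3f})")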

pderrenger (Member) commented

Thank you for sharing your findings and detailed explanation! Your discovery regarding the structure of the TFLite output tensor is correct and aligns with how YOLO models export outputs in TFLite format. The last rows of the tensor correspond to class-specific confidence scores, rather than a single "class ID" row. This difference in format compared to PyTorch outputs is a common source of confusion.

It's great to see that you've resolved the issue by correctly interpreting the output tensor, and the results now match your expectations. For others encountering similar issues, referring to the TFLite model's output details (e.g., shape and structure) is essential for correctly parsing the results.

If you need further assistance or have additional questions, feel free to ask. Kudos to you for the thorough debugging and sharing your solution with the community—it’s invaluable for others working on similar tasks!
