Skip to content

Commit

Permalink
Merge pull request #247 from roboflow/add-classification-keypoint-video
Browse files Browse the repository at this point in the history
Add support for classification and keypoint video processing
  • Loading branch information
capjamesg authored Apr 11, 2024
2 parents bda8d5d + c1a7700 commit 70dc66c
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 8 deletions.
2 changes: 1 addition & 1 deletion roboflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from roboflow.models import CLIPModel, GazeModel # noqa: F401
from roboflow.util.general import write_line

__version__ = "1.1.26"
__version__ = "1.1.27"


def check_key(api_key, model, notebook, num_retries=0):
Expand Down
1 change: 1 addition & 0 deletions roboflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def get_conditional_configuration_variable(key, default):
TYPE_OBJECT_DETECTION = "object-detection"
TYPE_INSTANCE_SEGMENTATION = "instance-segmentation"
TYPE_SEMANTIC_SEGMENTATION = "semantic-segmentation"
TYPE_KEYPOINT_DETECTION = "keypoint-detection"

DEFAULT_BATCH_NAME = "Pip Package Upload"
DEFAULT_JOB_NAME = "Annotated via API"
Expand Down
4 changes: 4 additions & 0 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
DEMO_KEYS,
TYPE_CLASSICATION,
TYPE_INSTANCE_SEGMENTATION,
TYPE_KEYPOINT_DETECTION,
TYPE_OBJECT_DETECTION,
TYPE_SEMANTIC_SEGMENTATION,
UNIVERSE_URL,
)
from roboflow.core.dataset import Dataset
from roboflow.models.classification import ClassificationModel
from roboflow.models.instance_segmentation import InstanceSegmentationModel
from roboflow.models.keypoint_detection import KeypointDetectionModel
from roboflow.models.object_detection import ObjectDetectionModel
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
from roboflow.util.annotations import amend_data_yaml
Expand Down Expand Up @@ -124,6 +126,8 @@ def __init__(
)
elif self.type == TYPE_SEMANTIC_SEGMENTATION:
self.model = SemanticSegmentationModel(self.__api_key, self.id)
elif self.type == TYPE_KEYPOINT_DETECTION:
self.model = KeypointDetectionModel(self.__api_key, self.id, version=version_without_workspace)
else:
self.model = None

Expand Down
4 changes: 3 additions & 1 deletion roboflow/models/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from PIL import Image

from roboflow.config import CLASSIFICATION_MODEL
from roboflow.models.inference import InferenceModel
from roboflow.util.image_utils import check_image_url
from roboflow.util.prediction import PredictionGroup


class ClassificationModel:
class ClassificationModel(InferenceModel):
"""
Run inference on a classification model hosted on Roboflow or served through
Roboflow Inference.
Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(
ClassificationModel Object
"""
# Instantiate different API URL parameters
super(ClassificationModel, self).__init__(api_key, id, version=version)
self.__api_key = api_key
self.id = id
self.name = name
Expand Down
2 changes: 2 additions & 0 deletions roboflow/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def predict_video(
self.type = "gaze-detection"
elif model_class == "CLIPModel":
self.type = "clip-embed-image"
elif model_class == "KeypointDetectionModel":
self.type = "keypoint-detection"
else:
raise Exception("Model type not supported for video inference.")

Expand Down
180 changes: 180 additions & 0 deletions roboflow/models/keypoint_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import base64
import io
import json
import os
import urllib

import requests
from PIL import Image

from roboflow.config import CLASSIFICATION_MODEL
from roboflow.models.inference import InferenceModel
from roboflow.util.image_utils import check_image_url
from roboflow.util.prediction import PredictionGroup


class KeypointDetectionModel(InferenceModel):
"""
Run inference on a classification model hosted on Roboflow or served through
Roboflow Inference.
"""

def __init__(
self,
api_key: str,
id: str,
name: str = None,
version: int = None,
local: bool = False,
):
"""
Create a ClassificationModel object through which you can run inference.
Args:
api_key (str): private roboflow api key
id (str): the workspace/project id
name (str): is the name of the project
version (int): version number
local (bool): whether the image is local or hosted
colors (dict): colors to use for the image
preprocessing (dict): preprocessing to use for the image
Returns:
ClassificationModel Object
"""
# Instantiate different API URL parameters
super(KeypointDetectionModel, self).__init__(api_key, id, version=version)
self.__api_key = api_key
self.id = id
self.name = name
self.version = version
self.base_url = "https://detect.roboflow.com/"

if self.name is not None and version is not None:
self.__generate_url()

if local:
print("initalizing local keypoint detection model hosted at :" + local)
self.base_url = local

def predict(self, image_path, hosted=False):
"""
Run inference on an image.
Args:
image_path (str): path to the image you'd like to perform prediction on
hosted (bool): whether the image you're providing is hosted on Roboflow
Returns:
PredictionGroup Object
Example:
>>> import roboflow
>>> rf = roboflow.Roboflow(api_key="")
>>> project = rf.workspace().project("PROJECT_ID")
>>> model = project.version("1").model
>>> prediction = model.predict("YOUR_IMAGE.jpg")
"""
self.__generate_url()
self.__exception_check(image_path_check=image_path)
# If image is local image
if not hosted:
# Open Image in RGB Format
image = Image.open(image_path).convert("RGB")
# Create buffer
buffered = io.BytesIO()
image.save(buffered, quality=90, format="JPEG")
img_dims = image.size
# Base64 encode image
img_str = base64.b64encode(buffered.getvalue())
img_str = img_str.decode("ascii")
# Post to API and return response
resp = requests.post(
self.api_url,
data=img_str,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
else:
# Create API URL for hosted image (slightly different)
self.api_url += "&image=" + urllib.parse.quote_plus(image_path)
# POST to the API
resp = requests.post(self.api_url)
img_dims = {"width": "0", "height": "0"}

if resp.status_code != 200:
raise Exception(resp.text)

return PredictionGroup.create_prediction_group(
resp.json(),
image_dims=img_dims,
image_path=image_path,
prediction_type=CLASSIFICATION_MODEL,
colors=self.colors,
)

def load_model(self, name, version):
"""
Load a model.
Args:
name (str): is the name of the model you'd like to load
version (int): version number
"""
# Load model based on user defined characteristics
self.name = name
self.version = version
self.__generate_url()

def __generate_url(self):
"""
Generate a Roboflow API URL on which to run inference.
Returns:
url (str): the url on which to run inference
"""

# Generates URL based on all parameters
splitted = self.id.rsplit("/")
without_workspace = splitted[1]
version = self.version
if not version and len(splitted) > 2:
version = splitted[2]

self.api_url = "".join(
[
self.base_url + without_workspace + "/" + str(version),
"?api_key=" + self.__api_key,
"&name=YOUR_IMAGE.jpg",
]
)

def __exception_check(self, image_path_check=None):
"""
Check to see if an image exists.
Args:
image_path_check (str): path to the image to check
Raises:
Exception: if image does not exist
"""
# Checks if image exists
if image_path_check is not None:
if not os.path.exists(image_path_check) and not check_image_url(image_path_check):
raise Exception("Image does not exist at " + image_path_check + "!")

def __str__(self):
"""
String representation of classification object
"""
json_value = {
"name": self.name,
"version": self.version,
"base_url": self.base_url,
}

return json.dumps(json_value, indent=2)
8 changes: 2 additions & 6 deletions roboflow/models/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
from roboflow.config import API_URL
from roboflow.models.inference import InferenceModel

SUPPORTED_ROBOFLOW_MODELS = [
"object-detection",
"classification",
"instance-segmentation",
]
SUPPORTED_ROBOFLOW_MODELS = ["object-detection", "classification", "instance-segmentation", "keypoint-detection"]

SUPPORTED_ADDITIONAL_MODELS = {
"clip": {
Expand Down Expand Up @@ -97,7 +93,7 @@ def predict(

for model in additional_models:
if model not in SUPPORTED_ADDITIONAL_MODELS:
raise Exception(f"Model {model} is no t supported for video inference.")
raise Exception(f"Model {model} is not supported for video inference.")

if inference_type not in SUPPORTED_ROBOFLOW_MODELS:
raise Exception(f"Model {inference_type} is not supported for video inference.")
Expand Down
2 changes: 2 additions & 0 deletions roboflow/roboflowpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from roboflow.config import APP_URL, get_conditional_configuration_variable, load_roboflow_api_key
from roboflow.models.classification import ClassificationModel
from roboflow.models.instance_segmentation import InstanceSegmentationModel
from roboflow.models.keypoint_detection import KeypointDetectionModel
from roboflow.models.object_detection import ObjectDetectionModel
from roboflow.models.semantic_segmentation import SemanticSegmentationModel

Expand Down Expand Up @@ -133,6 +134,7 @@ def infer(args):
"classification": ClassificationModel,
"instance-segmentation": InstanceSegmentationModel,
"semantic-segmentation": SemanticSegmentationModel,
"keypoint-detection": KeypointDetectionModel,
}[projectType]
model = modelClass(api_key, project_url)
kwargs = {}
Expand Down

0 comments on commit 70dc66c

Please sign in to comment.