-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #247 from roboflow/add-classification-keypoint-video
Add support for classification and keypoint video processing
- Loading branch information
Showing
8 changed files
with
195 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import base64 | ||
import io | ||
import json | ||
import os | ||
import urllib | ||
|
||
import requests | ||
from PIL import Image | ||
|
||
from roboflow.config import CLASSIFICATION_MODEL | ||
from roboflow.models.inference import InferenceModel | ||
from roboflow.util.image_utils import check_image_url | ||
from roboflow.util.prediction import PredictionGroup | ||
|
||
|
||
class KeypointDetectionModel(InferenceModel): | ||
""" | ||
Run inference on a classification model hosted on Roboflow or served through | ||
Roboflow Inference. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
api_key: str, | ||
id: str, | ||
name: str = None, | ||
version: int = None, | ||
local: bool = False, | ||
): | ||
""" | ||
Create a ClassificationModel object through which you can run inference. | ||
Args: | ||
api_key (str): private roboflow api key | ||
id (str): the workspace/project id | ||
name (str): is the name of the project | ||
version (int): version number | ||
local (bool): whether the image is local or hosted | ||
colors (dict): colors to use for the image | ||
preprocessing (dict): preprocessing to use for the image | ||
Returns: | ||
ClassificationModel Object | ||
""" | ||
# Instantiate different API URL parameters | ||
super(KeypointDetectionModel, self).__init__(api_key, id, version=version) | ||
self.__api_key = api_key | ||
self.id = id | ||
self.name = name | ||
self.version = version | ||
self.base_url = "https://detect.roboflow.com/" | ||
|
||
if self.name is not None and version is not None: | ||
self.__generate_url() | ||
|
||
if local: | ||
print("initalizing local keypoint detection model hosted at :" + local) | ||
self.base_url = local | ||
|
||
def predict(self, image_path, hosted=False): | ||
""" | ||
Run inference on an image. | ||
Args: | ||
image_path (str): path to the image you'd like to perform prediction on | ||
hosted (bool): whether the image you're providing is hosted on Roboflow | ||
Returns: | ||
PredictionGroup Object | ||
Example: | ||
>>> import roboflow | ||
>>> rf = roboflow.Roboflow(api_key="") | ||
>>> project = rf.workspace().project("PROJECT_ID") | ||
>>> model = project.version("1").model | ||
>>> prediction = model.predict("YOUR_IMAGE.jpg") | ||
""" | ||
self.__generate_url() | ||
self.__exception_check(image_path_check=image_path) | ||
# If image is local image | ||
if not hosted: | ||
# Open Image in RGB Format | ||
image = Image.open(image_path).convert("RGB") | ||
# Create buffer | ||
buffered = io.BytesIO() | ||
image.save(buffered, quality=90, format="JPEG") | ||
img_dims = image.size | ||
# Base64 encode image | ||
img_str = base64.b64encode(buffered.getvalue()) | ||
img_str = img_str.decode("ascii") | ||
# Post to API and return response | ||
resp = requests.post( | ||
self.api_url, | ||
data=img_str, | ||
headers={"Content-Type": "application/x-www-form-urlencoded"}, | ||
) | ||
else: | ||
# Create API URL for hosted image (slightly different) | ||
self.api_url += "&image=" + urllib.parse.quote_plus(image_path) | ||
# POST to the API | ||
resp = requests.post(self.api_url) | ||
img_dims = {"width": "0", "height": "0"} | ||
|
||
if resp.status_code != 200: | ||
raise Exception(resp.text) | ||
|
||
return PredictionGroup.create_prediction_group( | ||
resp.json(), | ||
image_dims=img_dims, | ||
image_path=image_path, | ||
prediction_type=CLASSIFICATION_MODEL, | ||
colors=self.colors, | ||
) | ||
|
||
def load_model(self, name, version): | ||
""" | ||
Load a model. | ||
Args: | ||
name (str): is the name of the model you'd like to load | ||
version (int): version number | ||
""" | ||
# Load model based on user defined characteristics | ||
self.name = name | ||
self.version = version | ||
self.__generate_url() | ||
|
||
def __generate_url(self): | ||
""" | ||
Generate a Roboflow API URL on which to run inference. | ||
Returns: | ||
url (str): the url on which to run inference | ||
""" | ||
|
||
# Generates URL based on all parameters | ||
splitted = self.id.rsplit("/") | ||
without_workspace = splitted[1] | ||
version = self.version | ||
if not version and len(splitted) > 2: | ||
version = splitted[2] | ||
|
||
self.api_url = "".join( | ||
[ | ||
self.base_url + without_workspace + "/" + str(version), | ||
"?api_key=" + self.__api_key, | ||
"&name=YOUR_IMAGE.jpg", | ||
] | ||
) | ||
|
||
def __exception_check(self, image_path_check=None): | ||
""" | ||
Check to see if an image exists. | ||
Args: | ||
image_path_check (str): path to the image to check | ||
Raises: | ||
Exception: if image does not exist | ||
""" | ||
# Checks if image exists | ||
if image_path_check is not None: | ||
if not os.path.exists(image_path_check) and not check_image_url(image_path_check): | ||
raise Exception("Image does not exist at " + image_path_check + "!") | ||
|
||
def __str__(self): | ||
""" | ||
String representation of classification object | ||
""" | ||
json_value = { | ||
"name": self.name, | ||
"version": self.version, | ||
"base_url": self.base_url, | ||
} | ||
|
||
return json.dumps(json_value, indent=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters