Skip to content

Commit

Permalink
Trying coremltools
Browse files Browse the repository at this point in the history
  • Loading branch information
maekawatoshiki committed Sep 15, 2024
1 parent ba67b92 commit f564f96
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
42 changes: 42 additions & 0 deletions snippets/coreml/mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import time

import coremltools
import torch
import torchvision

from PIL import Image
import numpy as np
from torchvision import transforms


labels = open("../../models/imagenet_classes.txt").readlines()
image = Image.open("../../models/cat.png")

model = torchvision.models.mobilenet_v3_large(pretrained=True)
model.eval()
model = torch.jit.trace(model, torch.zeros(1, 3, 224, 224))

coreml_model = coremltools.convert(
model,
inputs=[coremltools.TensorType(name="input_1", shape=(1, 3, 224, 224))],
outputs=[coremltools.TensorType(name="output_1")],
)

preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
input = preprocess(image)
input = input.unsqueeze(0).numpy()

for i in range(100):
start = time.time()
pred = coreml_model.predict({"input_1": input})["output_1"][0]
print(f"elapsed: {(time.time() - start) * 1000:.2f}ms")
output = np.argsort(pred)[::-1][:5]
output = [labels[i].strip() for i in output]
print(f"top5: {output}")
3 changes: 3 additions & 0 deletions snippets/coreml/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
coremltools==8.0b2
torch==2.3.0
torchvision

0 comments on commit f564f96

Please sign in to comment.