diff --git a/DenseDepth.ipynb b/DenseDepth.ipynb
index 793d418..9b6923f 100644
--- a/DenseDepth.ipynb
+++ b/DenseDepth.ipynb
@@ -3,8 +3,6 @@
"nbformat_minor": 0,
"metadata": {
"colab": {
- "name": "DenseDepth",
- "version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
@@ -23,14 +21,13 @@
"colab_type": "text"
},
"source": [
- ""
+ ""
]
},
{
"cell_type": "code",
"metadata": {
"id": "ahkR4C5dEnR0",
- "colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
@@ -38,9 +35,9 @@
"outputId": "491e7243-6dee-4e32-95c9-a817004775cd"
},
"source": [
- "!git clone https://github.com/ialhashim/DenseDepth.git"
+ "!git clone https://github.com/Avi241/DenseDepth.git"
],
- "execution_count": 1,
+ "execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -61,7 +58,6 @@
"cell_type": "code",
"metadata": {
"id": "fFQgwMlNExak",
- "colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 208
@@ -71,7 +67,7 @@
"source": [
"!wget https://s3-eu-west-1.amazonaws.com/densedepth/nyu.h5 -O ./DenseDepth/nyu.h5"
],
- "execution_count": 2,
+ "execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -96,7 +92,6 @@
"cell_type": "code",
"metadata": {
"id": "AiJKd6uLE9Gr",
- "colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 468
@@ -106,7 +101,7 @@
"source": [
"!cd DenseDepth; python test.py"
],
- "execution_count": 3,
+ "execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -146,7 +141,6 @@
"cell_type": "code",
"metadata": {
"id": "HjzqM74-FfyL",
- "colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 627
@@ -160,7 +154,7 @@
"plt.figure(figsize=(20,20))\n",
"plt.imshow( io.imread('./DenseDepth/test.png') )"
],
- "execution_count": 7,
+ "execution_count": null,
"outputs": [
{
"output_type": "execute_result",
@@ -191,20 +185,13 @@
{
"cell_type": "code",
"metadata": {
- "id": "Ra7GEtQBHUMS",
- "colab_type": "code",
- "colab": {}
+ "id": "Ra7GEtQBHUMS"
},
"source": [
- "!cd DenseDepth/PyTorch; python test_pytorch.py\n",
- "from matplotlib import pyplot as plt\n",
- "from skimage import io\n",
- "\n",
- "plt.figure(figsize=(20,20))\n",
- "plt.imshow( io.imread('./DenseDepth/test.png') )"
+ "!cd DenseDepth/PyTorch; python test_pytorch.py --cuda 1"
],
- "execution_count": 0,
+ "execution_count": null,
"outputs": []
}
]
-}
\ No newline at end of file
+}
diff --git a/PyTorch/model.py b/PyTorch/model_pt.py
similarity index 100%
rename from PyTorch/model.py
rename to PyTorch/model_pt.py
diff --git a/PyTorch/load_weight_from_keras.py b/PyTorch/test_pytorch.py
similarity index 74%
rename from PyTorch/load_weight_from_keras.py
rename to PyTorch/test_pytorch.py
index 76a9ced..197bb7b 100644
--- a/PyTorch/load_weight_from_keras.py
+++ b/PyTorch/test_pytorch.py
@@ -19,13 +19,18 @@
from torchvision import models
import torch.nn.functional as F
-from pytorch_model import PTModel
+from model_pt import PTModel
-# Argument Parser
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#device = 'cpu'
+#Argument Parser
parser = argparse.ArgumentParser(description='High Quality Monocular Depth Estimation via Transfer Learning')
parser.add_argument('--model', default='../nyu.h5', type=str, help='Trained Keras model file.')
parser.add_argument('--input', default='../examples/*.png', type=str, help='Input filename or folder.')
+parser.add_argument('--cuda', default=1, type=int, help='Enable of Disbale Cuda')
args = parser.parse_args()
+if args.cuda==0:
+ device = 'cpu'
# Custom object needed for inference and training
custom_objects = {'BilinearUpSampling2D': BilinearUpSampling2D, 'depth_loss_function': None}
@@ -84,7 +89,14 @@
pytorch_model.load_state_dict(keras_state_dict)
+# pytorch_model = torch.load('depth_3.pth')
pytorch_model.eval()
+#torch.save(pytorch_model,"depth_3.pth")
+pytorch_model.to(device)
+if device.__eq__('cuda'):
+ print("Loaded model to GPU")
+else:
+ print("Loaded model to CPU")
def my_DepthNorm(x, maxDepth):
@@ -94,20 +106,27 @@ def my_predict(model, images, minDepth=10, maxDepth=1000):
with torch.no_grad():
# Compute predictions
- predictions = model(images)
+ predictions = model(images.to(device))
# Put in expected range
- return np.clip(my_DepthNorm(predictions.numpy(), maxDepth=maxDepth), minDepth, maxDepth) / maxDepth
-
+ return np.clip(my_DepthNorm(predictions.cpu().numpy(), maxDepth=maxDepth), minDepth, maxDepth) / maxDepth
+import time
# # Input images
inputs = load_images( glob.glob(args.input) ).astype('float32')
+
pytorch_input = torch.from_numpy(inputs[0,:,:,:]).permute(2,0,1).unsqueeze(0)
-print(pytorch_input.shape)
+
+print("Input Shape = " + str(pytorch_input.shape))
# print('\nLoaded ({0}) images of size {1}.'.format(inputs.shape[0], inputs.shape[1:]))
-# # Compute results
-output = my_predict(pytorch_model,pytorch_input[0,:,:,:].unsqueeze(0))
-print(output.shape)
+# Compute results (When it prdeicts on first it takes some time after that it runs fast you can check with using for loop)
+for i in range(10):
+ tic = time.time()
+ output = my_predict(pytorch_model,pytorch_input[0,:,:,:].unsqueeze(0))
+ toc = time.time()
+ print("Time for test "+str(i)+" "+str(1000*(toc-tic))+" ms")
+
+print("Output Shape = " + str(output.shape))
plt.imshow(output[0,0,:,:])
plt.savefig('test.png')
plt.show()
diff --git a/README.md b/README.md
index 9a9d77c..601b040 100644
--- a/README.md
+++ b/README.md
@@ -5,12 +5,16 @@
Offical Keras (TensorFlow) implementaiton. If you have any questions or need more help with the code, contact the **first author**.
-**[Update]** Added a [Colab notebook](https://github.com/ialhashim/DenseDepth/blob/master/DenseDepth.ipynb) to try the method on the fly.
+**[Update]** Added a [Colab notebook](https://github.com/Avi241/DenseDepth/blob/master/DenseDepth.ipynb) to try the method on the fly.
**[Update]** Experimental TensorFlow 2.0 implementation added.
**[Update]** Experimental PyTorch code added.
+**[Update]** Implemented cuda version of Pytorch.
+
+**[Update]** Added Pytorch model for both Python3 and Python2. Please Find it [here.](https://drive.google.com/file/d/1Wi-w9uyErZwCn_x3DQeILntQgLdml72p/view?usp=share_link)
+
## Results
* KITTI