From 7071aef71b5a7a51d09d7112048779484a396b7d Mon Sep 17 00:00:00 2001 From: Chuck Cho Date: Tue, 29 Mar 2016 12:12:57 -0400 Subject: [PATCH] Minor changes: delete TODO.txt; ran/fixed lint; better formatting, etc. --- TODO(chuck).txt | 18 ------ include/caffe/blob.hpp | 11 ++-- include/caffe/layers/video_data_layer.hpp | 3 +- include/caffe/util/im2col.hpp | 4 +- include/caffe/util/io.hpp | 1 + scripts/cpp_lint.py | 6 +- src/caffe/blob.cpp | 69 ++++++++++++--------- src/caffe/layers/base_conv_layer.cpp | 36 +++++------ src/caffe/layers/cudnn_conv_layer.cpp | 4 -- src/caffe/layers/im2col_layer.cpp | 7 --- src/caffe/layers/pooling_layer.cpp | 1 - src/caffe/test/test_convolution_layer.cpp | 40 +++--------- src/caffe/test/test_deconvolution_layer.cpp | 6 -- src/caffe/test/test_embed_layer.cpp | 2 +- src/caffe/test/test_io.cpp | 57 ++++++++--------- src/caffe/test/test_pooling_layer.cpp | 4 -- src/caffe/test/test_video_data_layer.cpp | 14 ++--- src/caffe/util/im2col.cu | 15 ----- src/caffe/util/io.cpp | 38 ++++++------ 19 files changed, 133 insertions(+), 203 deletions(-) delete mode 100644 TODO(chuck).txt diff --git a/TODO(chuck).txt b/TODO(chuck).txt deleted file mode 100644 index a0c05fa9..00000000 --- a/TODO(chuck).txt +++ /dev/null @@ -1,18 +0,0 @@ -Just to remind myself (Chuck) some of the ongoing/undone work for Caffe+C3D merge -- make sure to do regression test with previous 4D blobs without length (or depth) - (n*c*h*w) -- DONE -- python wrapper -- check crop vs resize -- DONE -- misc tools: e.g. computer volume mean, extract c3d features -- check (N, C, W, H) vs (N, C, L, W, H) -- DONE -- params _h and _w don't work as intended yet -- DONE -- CUDNN supports 2dim convolution only -- clean wrapper for applying CPU / - non-CUDNN CUDA / CUDNN for 2dim/ndim conv -- DONE -- Make sure changes in these are safe: blob.cpp,hpp, im2col_layer.cpp - -WORK TO DO -- make sure all tests pass -- DONE -- merge into dextro-research/Vision -- IN REVIEW -- python -- regression test (older models should work) -- ONGOING -- when everything's done, remove print-out's, debug msgs diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index 54e5c2ab..f3ff7738 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -138,13 +138,14 @@ class Blob { inline int channels() const { return LegacyShape(1); } /// @brief Deprecated legacy shape accessor length: use shape(2) instead. inline int length() const { return (num_axes() == 5) ? LegacyShape(2) : 1; } - //inline int length() const { return LegacyShape(2); } /// @brief Deprecated legacy shape accessor height: use shape(3) instead. - inline int height() const { return (num_axes() == 5) ? LegacyShape(3) : LegacyShape(2); } - //inline int height() const { return LegacyShape(3); } + inline int height() const { + return (num_axes() == 5) ? LegacyShape(3) : LegacyShape(2); + } /// @brief Deprecated legacy shape accessor width: use shape(4) instead. - inline int width() const { return (num_axes() == 5) ? LegacyShape(4) : LegacyShape(3); } - //inline int width() const { return LegacyShape(4); } + inline int width() const { + return (num_axes() == 5) ? LegacyShape(4) : LegacyShape(3); + } inline int LegacyShape(int index) const { CHECK_LE(num_axes(), 5) << "Cannot use legacy accessors on Blobs with > 5 axes."; diff --git a/include/caffe/layers/video_data_layer.hpp b/include/caffe/layers/video_data_layer.hpp index 7fbbc5ea..4c4e62d1 100644 --- a/include/caffe/layers/video_data_layer.hpp +++ b/include/caffe/layers/video_data_layer.hpp @@ -15,8 +15,7 @@ // an extension the std::pair which used to store image filename and // its label (int). now, a frame number associated with the video filename // is needed (second param) to fully represent a video segment -struct triplet -{ +struct triplet { std::string first; int second, third; }; diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp index 54712cc6..f372c589 100644 --- a/include/caffe/util/im2col.hpp +++ b/include/caffe/util/im2col.hpp @@ -7,7 +7,7 @@ template void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - const int* dilation, Dtype* data_col, const bool forced_3d=false); + const int* dilation, Dtype* data_col, const bool forced_3d = false); template void im2col_cpu(const Dtype* data_im, const int channels, @@ -20,7 +20,7 @@ template void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes, const int* im_shape, const int* col_shape, const int* kernel_shape, const int* pad, const int* stride, - const int* dilation, Dtype* data_im, const bool forced_3d=false); + const int* dilation, Dtype* data_im, const bool forced_3d = false); template void col2im_cpu(const Dtype* data_col, const int channels, diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 1987805d..dc0b4cf3 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -5,6 +5,7 @@ #include #include // NOLINT(readability/streams) #include +#include #include "google/protobuf/message.h" diff --git a/scripts/cpp_lint.py b/scripts/cpp_lint.py index f750489f..57cb5303 100755 --- a/scripts/cpp_lint.py +++ b/scripts/cpp_lint.py @@ -1595,10 +1595,10 @@ def CheckCaffeAlternatives(filename, clean_lines, linenum, error): def CheckCaffeDataLayerSetUp(filename, clean_lines, linenum, error): """Except the base classes, Caffe DataLayer should define DataLayerSetUp instead of LayerSetUp. - + The base DataLayers define common SetUp steps, the subclasses should not override them. - + Args: filename: The name of the current file. clean_lines: A CleansedLines instance containing the file. @@ -1610,6 +1610,7 @@ def CheckCaffeDataLayerSetUp(filename, clean_lines, linenum, error): if ix >= 0 and ( line.find('void DataLayer::LayerSetUp') != -1 or line.find('void ImageDataLayer::LayerSetUp') != -1 or + line.find('void VideoDataLayer::LayerSetUp') != -1 or line.find('void MemoryDataLayer::LayerSetUp') != -1 or line.find('void WindowDataLayer::LayerSetUp') != -1): error(filename, linenum, 'caffe/data_layer_setup', 2, @@ -1622,6 +1623,7 @@ def CheckCaffeDataLayerSetUp(filename, clean_lines, linenum, error): line.find('void Base') == -1 and line.find('void DataLayer::DataLayerSetUp') == -1 and line.find('void ImageDataLayer::DataLayerSetUp') == -1 and + line.find('void VideoDataLayer::DataLayerSetUp') == -1 and line.find('void MemoryDataLayer::DataLayerSetUp') == -1 and line.find('void WindowDataLayer::DataLayerSetUp') == -1): error(filename, linenum, 'caffe/data_layer_setup', 2, diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 48d4e072..1a250bdb 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -34,15 +34,8 @@ void Blob::Reshape(const int num, const int channels, const int height, template void Blob::Reshape(const vector& shape) { - //std::cout << "-----------------------------"<::FromProto(const BlobProto& proto, bool reshape) { data_vec[i] = proto.double_data(i); } } else { - CHECK_EQ(count_, proto.data_size()); - for (int i = 0; i < count_; ++i) { - data_vec[i] = proto.data(i); + // normal case + if (!is_legacy_C3D_proto) { + CHECK_EQ(count_, proto.data_size()); + for (int i = 0; i < count_; ++i) { + data_vec[i] = proto.data(i); + } + // binary proto file created by legacy C3D code + } else { + for (int i = 0; i < count_; ++i) { + data_vec[i] = proto.diff(i); + } } } if (proto.double_diff_size() > 0) { @@ -551,7 +562,7 @@ void Blob::FromProto(const BlobProto& proto, bool reshape) { for (int i = 0; i < count_; ++i) { diff_vec[i] = proto.double_diff(i); } - } else if (proto.diff_size() > 0) { + } else if (proto.diff_size() > 0 && !is_legacy_C3D_proto) { CHECK_EQ(count_, proto.diff_size()); Dtype* diff_vec = mutable_cpu_diff(); for (int i = 0; i < count_; ++i) { diff --git a/src/caffe/layers/base_conv_layer.cpp b/src/caffe/layers/base_conv_layer.cpp index cf280e32..74ce0012 100644 --- a/src/caffe/layers/base_conv_layer.cpp +++ b/src/caffe/layers/base_conv_layer.cpp @@ -17,13 +17,11 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, channel_axis_ = bottom[0]->CanonicalAxisIndex(conv_param.axis()); const int num_axes = bottom[0]->num_axes(); int first_spatial_axis; - if (num_axes == 5 && channel_axis_ == 1 && - bottom[0]->shape(2) == 1) { + if (num_axes == 5 && channel_axis_ == 1 && bottom[0]->shape(2) == 1) { forced_3d_ = true; - first_spatial_axis = 3; // not 2 - num_spatial_axes_ = 2; // not 3 - } - else { + first_spatial_axis = 3; // not 2 + num_spatial_axes_ = 2; // not 3 + } else { forced_3d_ = false; first_spatial_axis = channel_axis_ + 1; num_spatial_axes_ = num_axes - first_spatial_axis; @@ -47,22 +45,21 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_ || num_kernel_dims == num_spatial_axes_ + 1) - << "kernel_size must be specified once, or once per spatial dimension " - << "(kernel_size specified " << num_kernel_dims << " times; " - << num_spatial_axes_ << " spatial dims)."; + << "kernel_size must be specified once, or once per spatial " + << "dimension (kernel_size specified " << num_kernel_dims + << " times; " << num_spatial_axes_ << " spatial dims)."; } else { CHECK(num_kernel_dims == 1 || num_kernel_dims == num_spatial_axes_) - << "kernel_size must be specified once, or once per spatial dimension " - << "(kernel_size specified " << num_kernel_dims << " times; " - << num_spatial_axes_ << " spatial dims)."; + << "kernel_size must be specified once, or once per spatial " + << "dimension (kernel_size specified " << num_kernel_dims + << " times; " << num_spatial_axes_ << " spatial dims)."; } if (num_kernel_dims == 1) { for (int i = 0; i < num_spatial_axes_; ++i) { kernel_shape_data[i] = conv_param.kernel_size(0); } - } - else if (num_kernel_dims == num_spatial_axes_) { + } else if (num_kernel_dims == num_spatial_axes_) { for (int i = 0; i < num_spatial_axes_; ++i) { kernel_shape_data[i] = conv_param.kernel_size(i); @@ -71,7 +68,7 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, if (num_kernel_dims == num_spatial_axes_ + 1) { for (int i = 0; i < num_spatial_axes_; ++i) { kernel_shape_data[i] = - conv_param.kernel_size(i + 1); // ignore the first kernel_size + conv_param.kernel_size(i + 1); // ignore the first kernel_size } } } @@ -113,7 +110,7 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, else if (num_stride_dims == num_spatial_axes_ ) stride_data[i] = conv_param.stride(i); else if (num_stride_dims == num_spatial_axes_ + 1) - stride_data[i] = conv_param.stride(i + 1); // ignore the first one + stride_data[i] = conv_param.stride(i + 1); // ignore the first one CHECK_GT(stride_data[i], 0) << "Stride dimensions must be nonzero."; } } @@ -152,7 +149,7 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, else if (num_pad_dims == num_spatial_axes_ ) pad_data[i] = conv_param.pad(i); else if (num_pad_dims == num_spatial_axes_ + 1) - pad_data[i] = conv_param.pad(i + 1); // ignore the first one + pad_data[i] = conv_param.pad(i + 1); // ignore the first one } } // Setup dilation dimensions (dilation_). @@ -182,7 +179,7 @@ void BaseConvolutionLayer::LayerSetUp(const vector*>& bottom, else if (num_dilation_dims == num_spatial_axes_ ) dilation_data[i] = conv_param.dilation(i); else if (num_dilation_dims == num_spatial_axes_ + 1) - dilation_data[i] = conv_param.dilation(i + 1); // ignore the first one + dilation_data[i] = conv_param.dilation(i + 1); // ignore the first one } // Special case: im2col is the identity for 1x1 convolution with stride 1 // and no padding, so flag for skipping the buffer and transformation. @@ -305,7 +302,8 @@ void BaseConvolutionLayer::Reshape(const vector*>& bottom, if (reverse_dimensions()) { conv_input_shape_data[i] = top[0]->shape(channel_axis_ + i + forced_3d_); } else { - conv_input_shape_data[i] = bottom[0]->shape(channel_axis_ + i + forced_3d_); + conv_input_shape_data[i] = bottom[0]->shape(channel_axis_ + i + + forced_3d_); } } // The im2col result buffer will only hold one image at a time to avoid diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index c69be445..129e8f7f 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -98,8 +98,6 @@ void CuDNNConvolutionLayer::Reshape( bottom_offset_ = this->bottom_dim_ / this->group_; top_offset_ = this->top_dim_ / this->group_; const bool forced_3d = this->forced_3d_; - //std::cout << "cudnn_conv_layer.cpp: num_spatial_axes="<num_spatial_axes_<num_axes()="<num_axes() << std::endl; - //std::cout << "caffe_conv:: weights[1]->num_axes()="<num_axes() << std::endl; - //std::cout << "caffe_conv:: in->num_axes()="<num_axes() << std::endl; - //std::cout << "caffe_conv:: out->num_axes()="<num_axes() << std::endl; const bool has_depth = (out->num_axes() == 5); - const bool forced_3d_ = has_depth && (in->shape(2) == 1); + const bool forced_3d = has_depth && (in->shape(2) == 1); if (!has_depth) { CHECK_EQ(4, out->num_axes()); } - //std::cout << "caffe_conv:: has_depth="<* in, ConvolutionParameter* conv_param, int k_g = in->shape(1) / groups; int o_head, k_head; // Convolution - vector weight_offset(4 + (has_depth && !forced_3d_)); + vector weight_offset(4 + (has_depth && !forced_3d)); vector in_offset(4 + has_depth); vector out_offset(4 + has_depth); Dtype* out_data = out->mutable_cpu_data(); @@ -96,29 +88,11 @@ void caffe_conv(const Blob* in, ConvolutionParameter* conv_param, if (in_z >= 0 && in_z < (has_depth ? in->shape(2) : 1) && in_y >= 0 && in_y < in->shape(2 + has_depth) && in_x >= 0 && in_x < in->shape(3 + has_depth)) { -/* - std::cout << - "o="<num()="<blob_top_->num()<<", "<< - "blob_top_->channels()="<blob_top_->channels()<<", "<< - "blob_top_->height()="<blob_top_->height()<<", "<< - "blob_top_->width()="<blob_top_->width()<<", "<< - std::endl; const Dtype* top_data = this->blob_top_->cpu_data(); for (int n = 0; n < this->blob_top_->num(); ++n) { for (int c = 0; c < this->blob_top_->channels(); ++c) { diff --git a/src/caffe/test/test_embed_layer.cpp b/src/caffe/test/test_embed_layer.cpp index ee1e9ec9..3fc33327 100644 --- a/src/caffe/test/test_embed_layer.cpp +++ b/src/caffe/test/test_embed_layer.cpp @@ -47,7 +47,7 @@ TYPED_TEST(EmbedLayerTest, TestSetUp) { layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_); ASSERT_EQ(this->blob_top_->num_axes(), kInputDim); EXPECT_EQ(this->blob_top_->shape(0), 4); - for (int i = 1; i < kInputDim - 1; ++ i) + for (int i = 1; i < kInputDim - 1; ++i) EXPECT_EQ(this->blob_top_->shape(i), 1); EXPECT_EQ(this->blob_top_->shape(kInputDim - 1), 10); } diff --git a/src/caffe/test/test_io.cpp b/src/caffe/test/test_io.cpp index a772e8f7..a1ebf8e4 100644 --- a/src/caffe/test/test_io.cpp +++ b/src/caffe/test/test_io.cpp @@ -5,6 +5,7 @@ #include #include +#include #include "gtest/gtest.h" @@ -425,11 +426,11 @@ TEST_F(IOTest, TestReadVideoToCVMatBasic) { "caffe/test/test_data/youtube_objects_dog_v0002_s006"; std::vector cv_imgs; bool read_video_result = ReadVideoToCVMat(path, - 1, // start frame - 16, // length (# frames) - 0, // new height - 0, // new width - true, // load as color + 1, // start frame + 16, // length (# frames) + 0, // new height + 0, // new width + true, // load as color &cv_imgs); EXPECT_EQ(read_video_result, true); EXPECT_EQ(cv_imgs.size(), 16); @@ -443,13 +444,13 @@ TEST_F(IOTest, TestReadVideoToCVMatNotEnoughFrames) { "caffe/test/test_data/youtube_objects_dog_v0002_s006"; std::vector cv_imgs; bool read_video_result = ReadVideoToCVMat(path, - 40, // start frame - 16, // length (# frames) - 0, // new height - 0, // new width - true, // load as color + 2, // start frame + 16, // length (# frames) + 0, // new height + 0, // new width + true, // load as color &cv_imgs); - EXPECT_EQ(read_video_result, false); // because there are only 48 frames + EXPECT_EQ(read_video_result, false); // because there are only 16 frames } TEST_F(IOTest, TestReadVideoToCVMatResize) { @@ -457,11 +458,11 @@ TEST_F(IOTest, TestReadVideoToCVMatResize) { "caffe/test/test_data/youtube_objects_dog_v0002_s006"; std::vector cv_imgs; bool read_video_result = ReadVideoToCVMat(path, - 1, // start frame - 16, // length (# frames) - 80, // new height - 100, // new width - true, // load as color + 1, // start frame + 16, // length (# frames) + 80, // new height + 100, // new width + true, // load as color &cv_imgs); EXPECT_EQ(read_video_result, true); EXPECT_EQ(cv_imgs.size(), 16); @@ -475,11 +476,11 @@ TEST_F(IOTest, TestReadVideoToCVMatFromAviBasic) { "caffe/test/test_data/UCF-101_Rowing_g16_c03.avi"; std::vector cv_imgs; bool read_video_result = ReadVideoToCVMat(path, - 1, // start frame - 19, // length (# frames) - 0, // new height - 0, // new width - true, // load as color + 1, // start frame + 19, // length (# frames) + 0, // new height + 0, // new width + true, // load as color &cv_imgs); EXPECT_EQ(read_video_result, true); EXPECT_EQ(cv_imgs.size(), 19); @@ -493,11 +494,11 @@ TEST_F(IOTest, TestReadVideoToCVMatFromAviResize) { "caffe/test/test_data/UCF-101_Rowing_g16_c03.avi"; std::vector cv_imgs; bool read_video_result = ReadVideoToCVMat(path, - 1, // start frame - 19, // length (# frames) + 1, // start frame + 19, // length (# frames) 123, // new height 300, // new width - true, // load as color + true, // load as color &cv_imgs); EXPECT_EQ(read_video_result, true); EXPECT_EQ(cv_imgs.size(), 19); @@ -511,11 +512,11 @@ TEST_F(IOTest, TestReadVideoToCVMatFromAviResizeAndGrayscale) { "caffe/test/test_data/UCF-101_Rowing_g16_c03.avi"; std::vector cv_imgs; bool read_video_result = ReadVideoToCVMat(path, - 1, // start frame - 16, // length (# frames) - 80, // new height + 1, // start frame + 16, // length (# frames) + 80, // new height 100, // new width - false, // load as color + false, // load as color &cv_imgs); EXPECT_EQ(read_video_result, true); EXPECT_EQ(cv_imgs.size(), 16); diff --git a/src/caffe/test/test_pooling_layer.cpp b/src/caffe/test/test_pooling_layer.cpp index 7f626cf3..e274851c 100644 --- a/src/caffe/test/test_pooling_layer.cpp +++ b/src/caffe/test/test_pooling_layer.cpp @@ -1215,7 +1215,6 @@ class CudnnNdPoolingLayerTest : public GPUDeviceTest { void TestForwardSquare() { LayerParameter layer_param; PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); - //pooling_param->set_kernel_size(2); pooling_param->mutable_kernel_shape()->add_dim(1); pooling_param->mutable_kernel_shape()->add_dim(2); pooling_param->mutable_kernel_shape()->add_dim(2); @@ -1418,8 +1417,6 @@ class CudnnNdPoolingLayerTest : public GPUDeviceTest { void TestForwardRectWide() { LayerParameter layer_param; PoolingParameter* pooling_param = layer_param.mutable_pooling_param(); - //pooling_param->set_kernel_h(2); - //pooling_param->set_kernel_w(3); pooling_param->mutable_kernel_shape()->add_dim(1); pooling_param->mutable_kernel_shape()->add_dim(2); pooling_param->mutable_kernel_shape()->add_dim(3); @@ -1755,7 +1752,6 @@ TYPED_TEST(CudnnNdPoolingLayerTest, TestSetupCudnnNd3d) { vector blob_shape(5); const int num = 2; const int channels = 2; - //this->blob_bottom_->Reshape(num, channels, 1, 6, 5); blob_shape[0]=num; blob_shape[1]=channels; blob_shape[2]=1; diff --git a/src/caffe/test/test_video_data_layer.cpp b/src/caffe/test/test_video_data_layer.cpp index ea2520f0..beaf6e86 100644 --- a/src/caffe/test/test_video_data_layer.cpp +++ b/src/caffe/test/test_video_data_layer.cpp @@ -34,7 +34,6 @@ class VideoDataLayerTest : public MultiDeviceTest { std::ofstream outfile(filename_.c_str(), std::ofstream::out); LOG(INFO) << "Using temporary file " << filename_; for (int i = 0; i < 5; ++i) { - //outfile << EXAMPLES_SOURCE_DIR "images/cat.jpg " << i; outfile << CMAKE_SOURCE_DIR "caffe/test/test_data/UCF-101_Rowing_g16_c03.avi " << "0 " << i; @@ -44,8 +43,6 @@ class VideoDataLayerTest : public MultiDeviceTest { MakeTempFilename(&filename_reshape_); std::ofstream reshapefile(filename_reshape_.c_str(), std::ofstream::out); LOG(INFO) << "Using temporary file " << filename_reshape_; - //reshapefile << EXAMPLES_SOURCE_DIR "images/cat.jpg " << 0; - //reshapefile << EXAMPLES_SOURCE_DIR "images/fish-bike.jpg " << 1; reshapefile << CMAKE_SOURCE_DIR "caffe/test/test_data/UCF-101_Rowing_g16_c03.avi " << "0 0"; @@ -103,15 +100,16 @@ TYPED_TEST(VideoDataLayerTest, TestResize) { video_data_param->set_batch_size(5); video_data_param->set_source(this->filename_.c_str()); video_data_param->set_new_length(16); - video_data_param->set_new_height(256); - video_data_param->set_new_width(256); + video_data_param->set_new_height(132); + video_data_param->set_new_width(123); video_data_param->set_shuffle(false); VideoDataLayer layer(param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); EXPECT_EQ(this->blob_top_data_->num(), 5); EXPECT_EQ(this->blob_top_data_->channels(), 3); - EXPECT_EQ(this->blob_top_data_->height(), 256); - EXPECT_EQ(this->blob_top_data_->width(), 256); + EXPECT_EQ(this->blob_top_data_->length(), 16); + EXPECT_EQ(this->blob_top_data_->height(), 132); + EXPECT_EQ(this->blob_top_data_->width(), 123); EXPECT_EQ(this->blob_top_label_->num(), 5); // Go through the data twice for (int iter = 0; iter < 2; ++iter) { @@ -140,12 +138,14 @@ TYPED_TEST(VideoDataLayerTest, TestReshape) { layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); EXPECT_EQ(this->blob_top_data_->num(), 1); EXPECT_EQ(this->blob_top_data_->channels(), 3); + EXPECT_EQ(this->blob_top_data_->length(), 16); EXPECT_EQ(this->blob_top_data_->height(), 240); EXPECT_EQ(this->blob_top_data_->width(), 320); // layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); EXPECT_EQ(this->blob_top_data_->num(), 1); EXPECT_EQ(this->blob_top_data_->channels(), 3); + EXPECT_EQ(this->blob_top_data_->length(), 16); EXPECT_EQ(this->blob_top_data_->height(), 240); EXPECT_EQ(this->blob_top_data_->width(), 320); } diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu index 50ac6959..a8f30a02 100644 --- a/src/caffe/util/im2col.cu +++ b/src/caffe/util/im2col.cu @@ -53,21 +53,6 @@ void im2col_gpu(const Dtype* data_im, const int channels, (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; int num_kernels = channels * height_col * width_col; // NOLINT_NEXT_LINE(whitespace/operators) -/* - std::cout<< - "height="< -#include -#include // Check if a given path is a regular file or a path -void check_path(const std::string& path, bool& is_file, bool& is_dir) -{ +void check_path(const std::string& path, bool* is_file, bool* is_dir) { struct stat path_stat; stat(path.c_str(), &path_stat); - is_file = S_ISREG(path_stat.st_mode); - is_dir = S_ISDIR(path_stat.st_mode); + *is_file = S_ISREG(path_stat.st_mode); + *is_dir = S_ISDIR(path_stat.st_mode); } const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. @@ -160,7 +160,7 @@ bool ReadVideoToCVMat(const string& path, // Check if path is a directory that holds extracted images from a video, // or a regular video file. bool is_video_file, is_path; - check_path(path, is_video_file, is_path); + check_path(path, &is_video_file, &is_path); if (!is_video_file && !is_path) { LOG(ERROR) << "Could not open or find file " << path; return false; @@ -173,9 +173,9 @@ bool ReadVideoToCVMat(const string& path, cv::VideoCapture cap; cap.open(path); - if (!cap.isOpened()){ - LOG(ERROR) << "Cannot open a video file=" << path; - return false; + if (!cap.isOpened()) { + LOG(ERROR) << "Cannot open a video file=" << path; + return false; } int num_frames = cap.get(CV_CAP_PROP_FRAME_COUNT); @@ -189,7 +189,6 @@ bool ReadVideoToCVMat(const string& path, cap.set(CV_CAP_PROP_POS_FRAMES, start_frame); for (size_t i = start_frame; i <= end_frame; ++i) { - //cap.set(CV_CAP_PROP_POS_FRAMES, i); cap.read(cv_img_origin); if (!cv_img_origin.data) { LOG(ERROR) << "Could not read frame=" << i << @@ -200,9 +199,8 @@ bool ReadVideoToCVMat(const string& path, // Force color if (is_color && cv_img_origin.channels() == 1) { cv::cvtColor(cv_img_origin, cv_img_origin, CV_GRAY2BGR); - } // Force grayscale - else if (!is_color && cv_img_origin.channels() == 3) { + } else if (!is_color && cv_img_origin.channels() == 3) { cv::cvtColor(cv_img_origin, cv_img_origin, CV_BGR2GRAY); } @@ -212,11 +210,12 @@ bool ReadVideoToCVMat(const string& path, cv_img = cv_img_origin; } cv_imgs->push_back(cv_img.clone()); + cv_img_origin.release(); } - } + cap.release(); // In case of a directory with extracted frames within - else { + } else { int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE); @@ -225,11 +224,9 @@ bool ReadVideoToCVMat(const string& path, char image_filename[256]; for (int i = start_frame; i <= end_frame; ++i) { - //sprintf(image_filename, "%s/%06d.jpg", path.c_str(), i); - sprintf(image_filename, "%s/image_%04d.jpg", path.c_str(), i); + snprintf(image_filename, sizeof(image_filename), "%s/image_%04d.jpg", + path.c_str(), i); cv_img_origin = cv::imread(image_filename, cv_read_flag); - //LOG(INFO) << "i=" << i << ", cv_img_origin.at(10,10)=" << (int) cv_img_origin.at(10,10); - //LOG(INFO)<<"Reading i="<