Skip to content

Commit

Permalink
Merge branch '730-tensor-reshape' into 'master'
Browse files Browse the repository at this point in the history
Resolve "tensor の reshape に可変長テンプレートを使用"

Closes #730

See merge request ricos/monolish!499
  • Loading branch information
fockl committed May 2, 2023
2 parents b0f4fe6 + 909ee2f commit 47b5aaa
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Unreleased
- Add mattens functions as blas-like operation <https://gitlab.ritc.jp/ricos/monolish/-/merge_requests/495> <https://github.com/ricosjp/monolish/issues/727>
- Add view1D of tensor_Dense <https://gitlab.ritc.jp/ricos/monolish/-/merge_requests/496> <https://github.com/ricosjp/monolish/issues/728>
- Add times/adds/axpy tests for view1D of matrix/tensor <https://gitlab.ritc.jp/ricos/monolish/-/merge_requests/498> <https://github.com/ricosjp/monolish/issues/729>
- Add variadic templates for reshape tensor <https://gitlab.ritc.jp/ricos/monolish/-/merge_requests/499> <https://github.com/ricosjp/monolish/issues/730>

### Changed
- Change operator[] as val.get()[] of value array <https://gitlab.ritc.jp/ricos/monolish/-/merge_requests/494> <https://github.com/ricosjp/monolish/issues/726>
Expand Down
95 changes: 94 additions & 1 deletion include/monolish/common/monolish_tensor_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,39 @@ template <typename Float> class tensor_Dense {
**/
[[nodiscard]] Float at(const std::vector<size_t> &pos) const;

/**
* @brief get element A[pos[0]][pos[1]]...
* @param pos std::vector position
* @return A[pos[0]][pos[1]]...
* @note
* - # of computation: 1
* - Multi-threading: false
* - GPU acceleration: false
**/
template <typename... Args>
[[nodiscard]] Float at(const std::vector<size_t> &pos, const size_t dim,
const Args... args) const {
std::vector<size_t> pos_copy = pos;
pos_copy.push_back(dim);
return this->at(pos_copy, args...);
};

/**
* @brief get element A[pos[0]][pos[1]]...
* @param pos std::vector position
* @return A[pos[0]][pos[1]]...
* @note
* - # of computation: 1
* - Multi-threading: false
* - GPU acceleration: false
**/
template <typename... Args>
[[nodiscard]] Float at(const size_t dim, const Args... args) const {
std::vector<size_t> pos(1);
pos[0] = dim;
return this->at(pos, args...);
};

/**
* @brief get element A[pos[0]][pos[1]]... (onlu CPU)
* @param pos std::vector position
Expand All @@ -301,6 +334,34 @@ template <typename Float> class tensor_Dense {
return static_cast<const tensor_Dense *>(this)->at(pos);
};

/**
* @brief get element A[pos[0]][pos[1]]... (onlu CPU)
* @param pos std::vector position
* @return A[pos[0]][pos[1]]...
* @note
* - # of computation: 1
* - Multi-threading: false
* - GPU acceleration: false
**/
template <typename... Args>
[[nodiscard]] Float at(const std::vector<size_t> &pos, const Args... args) {
return static_cast<const tensor_Dense *>(this)->at(pos, args...);
};

/**
* @brief get element A[pos[0]][pos[1]]... (onlu CPU)
* @param pos std::vector position
* @return A[pos[0]][pos[1]]...
* @note
* - # of computation: 1
* - Multi-threading: false
* - GPU acceleration: false
**/
template <typename... Args>
[[nodiscard]] Float at(const size_t dim, const Args... args) {
return static_cast<const tensor_Dense *>(this)->at(dim, args...);
};

/**
* @brief set element A[pos[0]][pos[1]]...
* @param pos std::vector position
Expand Down Expand Up @@ -584,7 +645,39 @@ template <typename Float> class tensor_Dense {
* - Multi-threading: false
* - GPU acceleration: false
**/
void reshape(const std::vector<size_t> &shape);
void reshape(const std::vector<int> &shape);

/**
* @brief Reshape tensor
* @param shape
* @note
* - # of computation: 1
* - Multi-threading: false
* - GPU acceleration: false
**/
template <typename... Args>
void reshape(const std::vector<int> &shape, const size_t dim,
const Args... args) {
std::vector<int> shape_copy = shape;
shape_copy.push_back(dim);
reshape(shape_copy, args...);
return;
}

/**
* @brief Reshape tensor
* @param shape
* @note
* - # of computation: 1
* - Multi-threading: false
* - GPU acceleration: false
**/
template <typename... Args> void reshape(const int dim, const Args... args) {
std::vector<int> shape(1);
shape[0] = dim;
reshape(shape, args...);
return;
}

/////////////////////////////////////////////

Expand Down
11 changes: 6 additions & 5 deletions src/utils/reshape/reshape_tensor_dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@ namespace monolish {
namespace tensor {

template <typename T>
void tensor_Dense<T>::reshape(const std::vector<size_t> &shape) {
void tensor_Dense<T>::reshape(const std::vector<int> &shape) {
Logger &logger = Logger::get_instance();
logger.util_in(monolish_func);

int minus = 0;
size_t N = 1;
this->shape.resize(shape.size());
for (size_t index = 0; index < shape.size(); ++index) {
if (shape[index] < 0) {
minus++;
} else {
N *= shape[index];
this->shape[index] = shape[index];
}
this->shape[index] = shape[index];
}
if (minus >= 2) {
throw std::runtime_error("negative value of shape should be 0 or 1");
Expand All @@ -27,7 +28,7 @@ void tensor_Dense<T>::reshape(const std::vector<size_t> &shape) {
if (minus == 1) {
std::size_t M = 1;
for (size_t index = 0; index < this->shape.size(); ++index) {
if (this->shape[index] < 0) {
if (shape[index] < 0) {
this->shape[index] = get_nnz() / N;
}
M *= this->shape[index];
Expand All @@ -42,8 +43,8 @@ void tensor_Dense<T>::reshape(const std::vector<size_t> &shape) {

logger.util_out();
}
template void tensor_Dense<double>::reshape(const std::vector<size_t> &shape);
template void tensor_Dense<float>::reshape(const std::vector<size_t> &shape);
template void tensor_Dense<double>::reshape(const std::vector<int> &shape);
template void tensor_Dense<float>::reshape(const std::vector<int> &shape);

} // namespace tensor
} // namespace monolish
53 changes: 53 additions & 0 deletions test/tensor/tensor_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,52 @@ template <typename T> bool fixed_size_test() {
return true;
}

template <typename T> bool reshape_test() {
monolish::tensor::tensor_Dense<T> tensor_dense(
{2, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

T val_orig = tensor_dense.at(1, 2);
if (val_orig != 9) {
std::cout << "original value failed" << std::endl;
return false;
}

tensor_dense.reshape(3, 4);
T val_new = tensor_dense.at(2, 0);
if (val_new != val_orig) {
std::cout << "reshaped value failed" << std::endl;
return false;
}

tensor_dense.reshape(2, 2, 3);
T val_new2 = tensor_dense.at(1, 0, 2);
if (val_new2 != val_orig) {
std::cout << "reshaped value2 failed" << std::endl;
return false;
}

tensor_dense.reshape(4, -1);
T val_new3 = tensor_dense.at(2, 2);
if (val_new3 != val_orig) {
std::cout << "reshaped value3 failed" << std::endl;
std::cout << val_new3 << "!=" << val_orig << std::endl;
return false;
}

tensor_dense.reshape(3, -1, 2);
T val_new4 = tensor_dense.at(2, 0, 0);
if (val_new4 != val_orig) {
std::cout << "reshaped value4 failed" << std::endl;
std::cout << val_new4 << "!=" << val_orig << std::endl;
return false;
}

std::cout << "Pass reshape test in " << get_type<T>() << " precision"
<< std::endl;

return true;
}

int main(int argc, char **argv) {

print_build_info();
Expand All @@ -325,5 +371,12 @@ int main(int argc, char **argv) {
return 3;
}

if (!reshape_test<double>()) {
return 4;
}
if (!reshape_test<float>()) {
return 4;
}

return 0;
}

0 comments on commit 47b5aaa

Please sign in to comment.