TLIB(TTM) is a header-only C++ library for high-performance tensor-times-matrix multiplication.
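It provides free C++ functions for computing, in parallel, the mode-$q$ tensor-times-matrix product of the general form

$$\underline{\mathbf{C}} = \underline{\mathbf{A}} \times_q \mathbf{B} \quad :\Leftrightarrow \quad \underline{\mathbf{C}}(i_1, \dots, i_{q-1}, j, i_{q+1}, \dots, i_p) = \sum_{i_q=1}^{n_q} \underline{\mathbf{A}}(i_1, \dots, i_q, \dots, i_p) \cdot \mathbf{B}(j, i_q),$$

where $q$ is the contraction mode, $\underline{\mathbf{A}}$ and $\underline{\mathbf{C}}$ are tensors of order $p$ with extents $(n_1, \dots, n_q, \dots, n_p)$ and $(n_1, \dots, n_{q-1}, m, n_{q+1}, \dots, n_p)$, respectively, and $\mathbf{B}$ is a matrix of shape $(m, n_q)$.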
Please see the wiki page for more information about library usage, function interfaces, and parameter settings.
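To make the mode-$q$ product concrete, here is a minimal reference sketch in plain C++. It is not part of TLIB(TTM): the function name `ttm_reference` and its parameters are hypothetical. It assumes that the tensors and the matrix are all stored in first-order (column-major) layout, that $q$ is 1-based, and that the matrix has shape $(m, n_q)$, which matches the `einsum` expression in the Python example below. It uses simple nested loops and makes no attempt to reach the performance of the library.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Naive reference for the mode-q tensor-times-matrix product (NOT the TLIB API).
// Assumptions: first-order (column-major) layout for A, B and C; q is 1-based.
//   A has extents n = (n_1,...,n_p), B has shape (m, n_q),
//   C has the extents of A with n_q replaced by m.
template <class T>
std::vector<T> ttm_reference(std::size_t q,
                             const std::vector<T>& A,
                             const std::vector<std::size_t>& n,
                             const std::vector<T>& B,
                             std::size_t m)
{
    assert(q >= 1 && q <= n.size());
    const std::size_t nq = n[q - 1];

    // Collapse the modes before and after q into two flat index ranges.
    const auto mid_lo = n.begin() + static_cast<std::ptrdiff_t>(q - 1);
    const auto mid_hi = n.begin() + static_cast<std::ptrdiff_t>(q);
    const std::size_t left  = std::accumulate(n.begin(), mid_lo, std::size_t{1},
                                              std::multiplies<std::size_t>());
    const std::size_t right = std::accumulate(mid_hi, n.end(), std::size_t{1},
                                              std::multiplies<std::size_t>());

    assert(A.size() == left * nq * right);
    assert(B.size() == m * nq);

    std::vector<T> C(left * m * right, T{0});
    for (std::size_t r = 0; r < right; ++r)
        for (std::size_t j = 0; j < m; ++j)
            for (std::size_t iq = 0; iq < nq; ++iq)
                for (std::size_t l = 0; l < left; ++l)
                    // C(l, j, r) += A(l, iq, r) * B(j, iq)
                    C[l + left * (j + m * r)] +=
                        A[l + left * (iq + nq * r)] * B[j + m * iq];
    return C;
}
```

For a tensor with extents 4x3x2 filled with 1, ..., 24 and a 5x4 matrix of ones, `ttm_reference(1, A, {4,3,2}, B, 5)` returns 30 elements, the first of which is 1+2+3+4 = 10.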
- Contraction mode $q$, tensor order $p$, tensor extents $n$ and tensor layout $\mathbf{\pi}$ can be chosen at runtime
- Supports any linear tensor layout, including the first-order and last-order storage layouts
- Offers two high-level interfaces and one C-like low-level interface for calling the tensor-times-matrix multiplication
- Implemented independently of a tensor data structure (can be used with `std::vector` and `std::array`)
- Currently supports float and double
- Multi-threading support with OpenMP v4.5 or higher
- Currently must be used with a BLAS implementation
- Performs in-place operations without transposing the tensor - no extra memory needed
- Reaches peak matrix-times-matrix performance for large tensors
- Requires the tensor elements to be contiguously stored in memory
- Element types must be either `float` or `double`
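The tensor layout $\mathbf{\pi}$ mentioned in the list above decides how far apart neighbouring elements of each mode lie in memory. The sketch below shows one common way to turn extents and a layout tuple into strides. It assumes that $\mathbf{\pi}$ is a 1-based permutation of the modes whose first entry names the mode that runs fastest in memory, so that $(1, 2, \dots, p)$ is the first-order layout and $(p, \dots, 2, 1)$ the last-order layout. The helpers `compute_strides` and `linear_offset` are hypothetical names for illustration, not functions of the library.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Strides for a linear tensor layout (illustrative helper, NOT the TLIB API).
// Assumption: layout is a 1-based permutation of 1..p and layout[0] names the
// mode whose index runs fastest in memory.
std::vector<std::size_t> compute_strides(const std::vector<std::size_t>& extents,
                                         const std::vector<std::size_t>& layout)
{
    assert(extents.size() == layout.size());
    const std::size_t p = extents.size();
    std::vector<std::size_t> strides(p, 1);  // the fastest mode keeps stride 1
    for (std::size_t r = 1; r < p; ++r) {
        const std::size_t prev = layout[r - 1] - 1;  // convert to 0-based mode
        const std::size_t curr = layout[r] - 1;
        strides[curr] = strides[prev] * extents[prev];
    }
    return strides;
}

// Memory offset of a 0-based multi-index for the given strides.
std::size_t linear_offset(const std::vector<std::size_t>& index,
                          const std::vector<std::size_t>& strides)
{
    assert(index.size() == strides.size());
    std::size_t k = 0;
    for (std::size_t i = 0; i < index.size(); ++i)
        k += index[i] * strides[i];
    return k;
}
```

With extents (4, 3, 2), the first-order layout (1, 2, 3) gives strides (1, 4, 12) and the last-order layout (3, 2, 1) gives strides (6, 2, 1). Under these assumptions a change of layout only changes the strides, not the stored element values.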
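```python
import numpy as np
import ttmpy as tp

# Tensor A of order 3 with extents 4x3x2 and matrix B of shape 5x4
A = np.arange(4*3*2, dtype=np.float64).reshape(4,3,2)
B = np.arange(5*4, dtype=np.float64).reshape(5,4)

# Mode-1 tensor-times-matrix product: contracts the first mode of A with the second mode of B
C = tp.ttm(1,A,B)

# Reference result computed with NumPy
D = np.einsum("ijk,xi->xjk", A, B)

# True if both results agree element-wise
np.all(np.equal(C,D))
```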
```cpp
/*main.cpp*/
#include <tlib/ttm.h>

#include <algorithm>
#include <vector>
#include <numeric>
#include <iostream>

int main()
{
  using value_t  = float;
  using tensor_t = tlib::tensor<value_t>;

  // A is an order-3 tensor with extents 4x3x2, B is a 5x4 matrix
  auto A = tensor_t( {4,3,2} );
  auto B = tensor_t( {5,4} );

  // Fill A with 1,2,...,24 and B with ones
  std::iota(A.begin(),A.end(),1);
  std::fill(B.begin(),B.end(),1);

  std::cout << "A = " << A << std::endl;
  std::cout << "B = " << B << std::endl;

  /*
    A =
    { 1  5  9 | 13 17 21
      2  6 10 | 14 18 22
      3  7 11 | 15 19 23
      4  8 12 | 16 20 24 };

    B =
    { 1 1 1 1
      1 1 1 1
      1 1 1 1
      1 1 1 1
      1 1 1 1 };
  */

  // Mode-1 tensor-times-matrix product: C(j,i2,i3) = sum over i1 of A(i1,i2,i3) * B(j,i1)
  auto C = A (1)* B;

  std::cout << "C = " << C << std::endl;

  /* for q=1
    C =
    { 1+..+4 5+..+8 9+..+12 | 13+..+16 17+..+20 21+..+24
      ..     ..     ..      | ..       ..       ..
      1+..+4 5+..+8 9+..+12 | 13+..+16 17+..+20 21+..+24 };
  */
}
```
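The values in the comments of the example above can be checked without the library. For $q=1$ and first-order layout, which the printed `A` suggests, the 4x3x2 tensor occupies memory exactly like a 4x6 column-major matrix, so the mode-1 product reduces to one matrix product of the 5x4 matrix with that 4x6 matrix. The sketch below assumes this layout and uses a hand-written triple loop as a stand-in for a BLAS GEMM call. The names `matmul_colmajor` and `example_mode1_product` are hypothetical, and the code is not a description of how the library is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Plain column-major matrix product C = B * A (stand-in for a BLAS GEMM call).
// B is m x k, A is k x n, and the result C is m x n.
std::vector<float> matmul_colmajor(const std::vector<float>& B,
                                   const std::vector<float>& A,
                                   std::size_t m, std::size_t k, std::size_t n)
{
    assert(B.size() == m * k);
    assert(A.size() == k * n);
    std::vector<float> C(m * n, 0.0f);
    for (std::size_t c = 0; c < n; ++c)
        for (std::size_t kk = 0; kk < k; ++kk)
            for (std::size_t r = 0; r < m; ++r)
                C[r + m * c] += B[r + m * kk] * A[kk + k * c];
    return C;
}

// Rebuilds the data of the example: A is 4x3x2 holding 1..24, B is 5x4 of ones.
std::vector<float> example_mode1_product()
{
    std::vector<float> A(4 * 3 * 2);
    std::iota(A.begin(), A.end(), 1.0f);
    std::vector<float> B(5 * 4, 1.0f);

    // In first-order layout the same memory of A can be read as a 4 x 6
    // column-major matrix, so no copy or transposition is needed here.
    return matmul_colmajor(B, A, 5, 4, 6);
}
```

Each of the five rows of the 5x6 result holds the sums 10, 26, 42, 58, 74 and 90, that is 1+..+4, 5+..+8, and so on up to 21+..+24, in agreement with the commented output for `C`.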