diff --git a/Compiler/types.py b/Compiler/types.py index 6aa31c4e5..29254268a 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -6322,6 +6322,28 @@ def sort(self, n_threads=None, batcher=False, n_bits=None): from . import sorting sorting.radix_sort(self, self, n_bits=n_bits) + def to_row_matrix(self): + """ + Returns the array as 1xN matrix. + + Warning: This operation is in-place (without copying data), i.e., all changes to the values of the matrix will also affect the original array. + :return: Matrix + """ + assert self.value_type.n_elements() == 1 and \ + self.value_type.mem_size() == 1 + return Matrix(1, self.length, self.value_type, address=self.address) + + def to_column_matrix(self): + """ + Returns the array as Nx1 matrix. + + Warning: This operation is in-place (without copying data), i.e., all changes to the values of the matrix will also affect the original array. + :return: Matrix + """ + assert self.value_type.n_elements() == 1 and \ + self.value_type.mem_size() == 1 + return Matrix(self.length, 1, self.value_type, address=self.address) + def Array(self, size): # compatibility with registers return Array(size, self.value_type) diff --git a/Programs/Source/test_dot.mpc b/Programs/Source/test_dot.mpc index 92f0bad0c..58216b755 100644 --- a/Programs/Source/test_dot.mpc +++ b/Programs/Source/test_dot.mpc @@ -49,35 +49,32 @@ def test_matrix(expected, actual): crash() -break_point() -def hacky_array_dot_matrix(arr, mat): - # Arrays sadly do not have a dot function, therefore the array is converted into a 1 times n Matrix by copying memory addresses. - tmp = sint.Matrix(rows=1, columns=len(arr), address=arr.address) - result = tmp.dot(mat) - return sint.Array(mat.shape[1], result.address) - start_timer(3) -e3 = hacky_array_dot_matrix(a, c) +e3 = a.to_row_matrix().dot(c).to_array() # b[0] = e3[0] -f3 = hacky_array_dot_matrix(b, d) +f3 = b.to_row_matrix().dot(d).to_array() +g3 = c.dot(b.to_column_matrix()).to_array() stop_timer(3) e3 = e3.reveal() f3 = f3.reveal() +g3 = g3.reveal() e3.print_reveal_nested() f3.print_reveal_nested() +g3.print_reveal_nested() test_array([70, 80, 90], e3) test_array([56, 50, 44, 38], f3) +test_array([10, 28, 46, 64], g3) start_timer(4) -e4 = hacky_array_dot_matrix(a, c) +e4 = a.to_row_matrix().dot(c).to_array() b[-1] = e4[0] -f4 = hacky_array_dot_matrix(b, d) +f4 = b.to_row_matrix().dot(d).to_array() stop_timer(4)