Skip to content

Commit

Permalink
Merge pull request #1417 from vincent-ehrmanntraut/mergeable-matmulsm…
Browse files Browse the repository at this point in the history
…-squashed

Make matmulsm mergeable (Fixes #1407)
  • Loading branch information
mkskeller authored Jul 2, 2024
2 parents a271f54 + c745269 commit 41999a3
Show file tree
Hide file tree
Showing 10 changed files with 473 additions and 114 deletions.
71 changes: 64 additions & 7 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,70 @@ def keep_text_order(inst, n):
keep_text_order(instr, n)
elif isinstance(instr, RawInputInstruction):
keep_merged_order(instr, n, RawInputInstruction)
elif isinstance(instr, matmulsm):
if options.preserve_mem_order:
strict_mem_access(n, last_mem_read, last_mem_write)
else:
if instr.indices_values is not None and instr.first_factor_base_addresses is not None and instr.second_factor_base_addresses is not None:
# Determine which values get accessed by the MATMULSM instruction and only add the corresponding dependencies.
for matmul_idx in range(len(instr.first_factor_base_addresses)):
start_time = time.time()
first_base = instr.first_factor_base_addresses[matmul_idx]
second_base = instr.second_factor_base_addresses[matmul_idx]

first_factor_row_indices = instr.indices_values[4 * matmul_idx]
first_factor_column_indices = instr.indices_values[4 * matmul_idx + 1]
second_factor_row_indices = instr.indices_values[4 * matmul_idx + 2]
second_factor_column_indices = instr.indices_values[4 * matmul_idx + 3]

first_factor_row_length = instr.args[12 * matmul_idx + 10]
second_factor_row_length = instr.args[12 * matmul_idx + 11]

# Due to the potentially very large number of inputs on large matrices, adding dependencies to
# all inputs may take a long time. Therefore, we only partially build the dependencies on
# large matrices and output a warning.
# The threshold of 2_250_000 values per matrix is equivalent to multiplying two 1500x1500
# matrices. Experiments showed that multiplying two 1700x1700 matrices requires roughly 10 seconds on an i7-1370P,
# so this threshold should lead to acceptable compile times even on slower processors.
first_factor_total_number_of_values = instr.args[12 * matmul_idx + 3] * instr.args[12 * matmul_idx + 4]
second_factor_total_number_of_values = instr.args[12 * matmul_idx + 4] * instr.args[12 * matmul_idx + 5]
max_dependencies_per_matrix = 1500**2
if first_factor_total_number_of_values > max_dependencies_per_matrix or second_factor_total_number_of_values > max_dependencies_per_matrix:
if block.warn_about_mem and not block.parent.warned_about_mem:
print('WARNING: Order of memory instructions not preserved due to long vector, errors possible')
block.parent.warned_about_mem = True

# Add dependencies to the first factor.
# If the size of the matrix exceeds the max_dependencies_per_matrix, only a limited number
# of rows will be processed.
for i in range(min(instr.args[12 * matmul_idx + 3], max_dependencies_per_matrix // instr.args[12 * matmul_idx + 4] + 1)):
for k in range(instr.args[12 * matmul_idx + 4]):
first_factor_addr = first_base + \
first_factor_row_length * first_factor_row_indices[i] + \
first_factor_column_indices[k]
handle_mem_access(first_factor_addr, 's', last_mem_read_of, last_mem_write_of)

# Add dependencies to the second factor.
# If the size of the matrix exceeds the max_dependencies_per_matrix, only a limited number
# of rows will be processed.
for k in range(min(instr.args[12 * matmul_idx + 4], max_dependencies_per_matrix // instr.args[12 * matmul_idx + 5] + 1)):
if (time.time() - start_time) > 10:
# Abort building the dependencies if that takes too much time.
if block.warn_about_mem and not block.parent.warned_about_mem:
print('WARNING: Order of memory instructions not preserved due to long vector, errors possible')
block.parent.warned_about_mem = True
break

for j in range(instr.args[12 * matmul_idx + 5]):
second_factor_addr = second_base + \
second_factor_row_length * second_factor_row_indices[k] + \
second_factor_column_indices[j]
handle_mem_access(second_factor_addr, 's', last_mem_read_of, last_mem_write_of)
else:
# If the accessed values cannot be determined, be cautious and add a dependency on every recorded memory write.
for i in last_mem_write_of.values():
for j in i:
add_edge(j, n)

if isinstance(instr, merge_classes):
open_nodes.add(n)
Expand Down Expand Up @@ -622,13 +686,6 @@ def keep_text_order(inst, n):
strict_mem_access(n, scope.write, scope.read)
if not options.preserve_mem_order:
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
elif isinstance(instr, matmulsm):
if options.preserve_mem_order:
strict_mem_access(n, last_mem_read, last_mem_write)
else:
for i in last_mem_write_of.values():
for j in i:
add_edge(j, n)
# keep I/O instructions in order
elif isinstance(instr, IOInstruction):
if last_print_str is not None:
Expand Down
50 changes: 35 additions & 15 deletions Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,7 +2484,7 @@ def get_repeat(self):
return sum(reduce(operator.mul, self.args[i + 3:i + 6])
for i in range(0, len(self.args), 6))

class matmulsm(matmul_base):
class matmulsm(matmul_base, base.Mergeable):
""" Secret matrix multiplication reading directly from memory.
:param: result (sint vector in row-first order)
Expand All @@ -2494,26 +2494,46 @@ class matmulsm(matmul_base):
:param: number of columns in first factor and rows in second factor (int)
:param: number of columns in second factor and result (int)
:param: rows of first factor to use (regint vector, length as number of rows in first factor)
:param: columns of first factor to use (regint vector, length below)
:param: rows of second factor to use (regint vector, length below)
:param: columns of second factor to use (regint vector, length below)
:param: number of columns of first / rows of second factor to use (int)
:param: number of columns of second factor to use (int)
:param: columns of first factor to use (regint vector, length as number of columns in the first factor)
:param: rows of second factor to use (regint vector, length as number of columns in the first factor)
:param: columns of second factor to use (regint vector, length as number of columns in the second factor)
:param: total number of columns in the first factor, equal to used number of columns when all columns are used (int)
:param: total number of columns in the second factor, equal to used number of columns when all columns are used (int)
"""
code = base.opcodes['MATMULSM']
arg_format = ['sw','ci','ci','int','int','int','ci','ci','ci','ci',
'int','int']

def __init__(self, *args, **kwargs):
arg_format = itertools.cycle(['sw','ci','ci','int','int','int','ci','ci','ci','ci',
'int','int'])

def __init__(self, *args,
first_factor_base_addresses=None,
second_factor_base_addresses=None,
indices_values=None,
**kwargs):
matmul_base.__init__(self, *args, **kwargs)
for i in range(2):
assert args[6 + i].size == args[3 + i]
for i in range(2):
assert args[8 + i].size == args[4 + i]
for matmul_index in range(len(args) // 12):
for i in range(2):
assert args[12 * matmul_index + 6 + i].size == args[12 * matmul_index + 3 + i]
for i in range(2):
assert args[12 * matmul_index + 8 + i].size == args[12 * matmul_index + 4 + i]

# These are used to reconstruct the accessed memory addresses in the allocator.
self.first_factor_base_addresses = first_factor_base_addresses
self.second_factor_base_addresses = second_factor_base_addresses
self.indices_values = indices_values

if first_factor_base_addresses is not None:
assert len(first_factor_base_addresses) == len(second_factor_base_addresses)
if indices_values is not None:
assert len(indices_values) == 4 * len(first_factor_base_addresses)

def add_usage(self, req_node):
super(matmulsm, self).add_usage(req_node)
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
for i in range(0, len(self.args), 12):
req_node.increment(('matmul', (self.args[i + 3], self.args[i + 4], self.args[i + 5])), 1)

def get_repeat(self):
return sum(reduce(operator.mul, self.args[i + 3:i + 6])
for i in range(0, len(self.args), 12))

class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
""" Secret 2D convolution.
Expand Down
16 changes: 14 additions & 2 deletions Compiler/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2668,12 +2668,24 @@ def store_in_mem(self, address):
self._store_in_mem(address, stms, stmsi)

@classmethod
def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None):
def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None, indices_values=None):
if indices is None:
indices = [regint.inc(i) for i in (n, m, m, l)]
indices_values = [list(range(i)) for i in (n, m, m, l)]
res = cls(size=indices[0].size * indices[3].size)

if isinstance(A, int) and isinstance(B, int):
first_factor_base_addresses = [A]
second_factor_base_addresses = [B]
else:
first_factor_base_addresses = None
second_factor_base_addresses = None

matmulsm(res, regint(A), regint(B), len(indices[0]), len(indices[1]),
len(indices[3]), *(list(indices) + [m, l]))
len(indices[3]), *(list(indices) + [m, l]),
first_factor_base_addresses=first_factor_base_addresses,
second_factor_base_addresses=second_factor_base_addresses,
indices_values=indices_values)
return res

@vectorize_init
Expand Down
7 changes: 3 additions & 4 deletions Processor/Instruction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
get_vector(num_var_args, start, s);
break;
case MATMULSM:
get_ints(r, s, 3);
get_vector(9, start, s);
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;

// read from file, input is opcode num_args,
Expand Down Expand Up @@ -1117,8 +1117,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.Procp.matmuls(Proc.Procp.get_S(), *this);
return;
case MATMULSM:
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
Proc.read_Ci(r[1]), Proc.read_Ci(r[2]));
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this);
return;
case CONV2DS:
Proc.Procp.protocol.conv2ds(Proc.Procp, *this);
Expand Down
8 changes: 6 additions & 2 deletions Processor/Processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,12 @@ class SubProcessor
void mulrs(const vector<int>& reg);
void dotprods(const vector<int>& reg, int size);
void matmuls(const vector<T>& source, const Instruction& instruction);
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction, size_t a,
size_t b);
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction);

void matmulsm_finalize_batch(vector<int>::const_iterator startMatmul, int startI, int startJ,
vector<int>::const_iterator endMatmul,
int endI, int endJ);

void conv2ds(const Instruction& instruction);

void secure_shuffle(const Instruction& instruction);
Expand Down
Loading

0 comments on commit 41999a3

Please sign in to comment.