Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DAPHNE-#811] Extension of EwBinaryMat and EwUnaryMat Kernels to Support Sparse and Dense Matrix Addition #812

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/runtime/local/datastructures/CSRMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class CSRMatrix : public Matrix<ValueType> {
// TODO Here we could reduce the allocated size of the values and
// colIdxs arrays.
}

/**
 * @brief Returns a mutable pointer to this matrix's values array.
 *
 * The pointer is whatever the `values` member hands out via `get()`.
 *
 * @return Pointer to the first stored value.
 */
ValueType * getValues() {
    ValueType * const rawValues = values.get();
    return rawValues;
}
Expand Down
14 changes: 13 additions & 1 deletion src/runtime/local/datastructures/DenseMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,19 @@ class DenseMatrix : public Matrix<ValueType>
/**
 * @brief Reads the element at the given row and column.
 *
 * The linear offset into the values array is computed by `pos()`, which is
 * also told whether this matrix is a partial buffer.
 *
 * @param rowIdx Row index of the requested element.
 * @param colIdx Column index of the requested element.
 * @return The value stored at that position.
 */
ValueType get(size_t rowIdx, size_t colIdx) const override {
    const auto offset = pos(rowIdx, colIdx, isPartialBuffer());
    return getValues()[offset];
}


/**
 * @brief Counts the elements of this matrix that are not equal to zero.
 *
 * Every cell is visited through `get()`, so the cost is proportional to
 * `numRows * numCols` on each call; the result is not cached.
 *
 * @return Number of elements that compare unequal to `ValueType(0)`.
 */
size_t getNumNonZeros() const {
    size_t nnz = 0;
    for (size_t rowIdx = 0; rowIdx < numRows; ++rowIdx)
        for (size_t colIdx = 0; colIdx < numCols; ++colIdx)
            nnz += (get(rowIdx, colIdx) != ValueType(0)) ? 1 : 0;
    return nnz;
}

void set(size_t rowIdx, size_t colIdx, ValueType value) override {
auto vals = getValues();
vals[pos(rowIdx, colIdx)] = value;
Expand Down
240 changes: 212 additions & 28 deletions src/runtime/local/kernels/EwBinaryMat.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <algorithm>
#include <runtime/local/context/DaphneContext.h>
#include <runtime/local/datastructures/CSRMatrix.h>
#include <runtime/local/datastructures/DataObjectFactory.h>
Expand Down Expand Up @@ -109,6 +110,148 @@ struct EwBinaryMat<DenseMatrix<VTres>, DenseMatrix<VTlhs>, DenseMatrix<VTrhs>> {
}
};

// ----------------------------------------------------------------------------
// DenseMatrix <- CSRMatrix, CSRMatrix
// ----------------------------------------------------------------------------

/**
 * @brief Element-wise binary operation on two CSR matrices, producing a
 * dense result.
 *
 * Only `BinaryOpCode::ADD` is implemented. Both inputs must have identical
 * dimensions. If `res` is a null pointer, a new dense matrix of the input
 * shape is created; in either case the result buffer is zero-filled before
 * the non-zero cells are written.
 *
 * Each row is processed by walking the two rows' column-index arrays in
 * lock-step (this assumes the column indexes within a row are stored in
 * ascending order).
 *
 * @throws std::runtime_error if the shapes differ or the op code is not ADD.
 */
template<typename VT>
struct EwBinaryMat<DenseMatrix<VT>, CSRMatrix<VT>, CSRMatrix<VT>> {
    static void apply(BinaryOpCode opCode, DenseMatrix<VT> *& res, const CSRMatrix<VT> * lhs, const CSRMatrix<VT> * rhs, DCTX(ctx)) {
        const size_t numRows = lhs->getNumRows();
        const size_t numCols = lhs->getNumCols();

        const bool sameShape = (numRows == rhs->getNumRows()) && (numCols == rhs->getNumCols());
        if (!sameShape)
            throw std::runtime_error("EwBinaryMat(DenseMatrix <- CSRMatrix, CSRMatrix) - lhs and rhs must have the same dimensions.");

        if (res == nullptr)
            res = DataObjectFactory::create<DenseMatrix<VT>>(numRows, numCols, false);

        VT * const out = res->getValues();

        EwBinaryScaFuncPtr<VT, VT, VT> func = getEwBinaryScaFuncPtr<VT, VT, VT>(opCode);

        // Cells that are zero in both inputs are never touched below.
        std::fill(out, out + numRows * numCols, VT(0));

        if (opCode != BinaryOpCode::ADD)
            throw std::runtime_error("EwBinaryMat(DenseMatrix <- CSRMatrix, CSRMatrix) - unsupported BinaryOpCode");

        for (size_t r = 0; r < numRows; ++r) {
            const size_t nnzL = lhs->getNumNonZeros(r);
            const size_t nnzR = rhs->getNumNonZeros(r);

            const VT * valsL = lhs->getValues(r);
            const size_t * colsL = lhs->getColIdxs(r);

            const VT * valsR = rhs->getValues(r);
            const size_t * colsR = rhs->getColIdxs(r);

            VT * const outRow = out + r * numCols;

            size_t iL = 0;
            size_t iR = 0;

            // Phase 1: both rows still have entries; advance whichever has
            // the smaller column index, or both on a match.
            while (iL < nnzL && iR < nnzR) {
                const size_t cL = colsL[iL];
                const size_t cR = colsR[iR];
                if (cL < cR) {
                    outRow[cL] = func(valsL[iL], VT(0), ctx);
                    ++iL;
                }
                else if (cR < cL) {
                    outRow[cR] = func(VT(0), valsR[iR], ctx);
                    ++iR;
                }
                else {
                    outRow[cL] = func(valsL[iL], valsR[iR], ctx);
                    ++iL;
                    ++iR;
                }
            }

            // Phase 2: at most one of the rows has entries left.
            for (; iL < nnzL; ++iL)
                outRow[colsL[iL]] = func(valsL[iL], VT(0), ctx);
            for (; iR < nnzR; ++iR)
                outRow[colsR[iR]] = func(VT(0), valsR[iR], ctx);
        }
    }
};


// ----------------------------------------------------------------------------
// DenseMatrix <- CSRMatrix, DenseMatrix
// ----------------------------------------------------------------------------

/**
 * @brief Element-wise binary operation of a CSR matrix (lhs) and a dense
 * matrix (rhs), producing a dense result.
 *
 * Only `BinaryOpCode::ADD` is implemented. The result has the shape of
 * `lhs`. The rhs may either match that shape or be broadcast: an rhs with a
 * single row is reused for every row, and an rhs with a single column is
 * reused for every column.
 *
 * If `res` is a null pointer, a new dense matrix of the lhs shape is
 * created. Every result cell is overwritten.
 *
 * @throws std::runtime_error if the rhs shape is neither equal to the lhs
 * shape nor broadcastable, or if the op code is not ADD.
 */
template<typename VT>
struct EwBinaryMat<DenseMatrix<VT>, CSRMatrix<VT>, DenseMatrix<VT>> {
    static void apply(BinaryOpCode opCode, DenseMatrix<VT> *& res, const CSRMatrix<VT> * lhs, const DenseMatrix<VT> * rhs, DCTX(ctx)) {
        const size_t numRows = lhs->getNumRows();
        const size_t numCols = lhs->getNumCols();

        if((numRows != rhs->getNumRows() && rhs->getNumRows() != 1) || (numCols != rhs->getNumCols() && rhs->getNumCols() != 1))
            throw std::runtime_error("EwBinaryMat(Dense) - lhs and rhs must have the same dimensions (or broadcast)");

        if(res == nullptr)
            res = DataObjectFactory::create<DenseMatrix<VT>>(numRows, numCols, false);

        VT * valuesRes = res->getValues();

        EwBinaryScaFuncPtr<VT, VT, VT> func = getEwBinaryScaFuncPtr<VT, VT, VT>(opCode);

        switch(opCode) {
            case BinaryOpCode::ADD: {
                // The shape check above admits a one-row and/or one-column
                // rhs, so both coordinates must be clamped to 0 when
                // broadcasting; otherwise a one-column rhs would be read
                // out of range for every colIdx > 0.
                const bool broadcastRow = (rhs->getNumRows() == 1);
                const bool broadcastCol = (rhs->getNumCols() == 1);
                for(size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
                    const size_t rhsRow = broadcastRow ? 0 : rowIdx;
                    for(size_t colIdx = 0; colIdx < numCols; colIdx++) {
                        const size_t rhsCol = broadcastCol ? 0 : colIdx;
                        auto lhsVal = lhs->get(rowIdx, colIdx);
                        auto rhsVal = rhs->get(rhsRow, rhsCol);
                        valuesRes[rowIdx * numCols + colIdx] = func(lhsVal, rhsVal, ctx);
                    }
                }
                break;
            }
            default:
                throw std::runtime_error("EwBinaryMat(Dense) - unsupported BinaryOpCode");
        }
    }
};

// ----------------------------------------------------------------------------
// DenseMatrix <- DenseMatrix, CSRMatrix
// ----------------------------------------------------------------------------

/**
 * @brief Element-wise binary operation of a dense matrix (lhs) and a CSR
 * matrix (rhs), producing a dense result.
 *
 * Only `BinaryOpCode::ADD` is implemented. The result has the shape of
 * `lhs`. The rhs may either match that shape or be broadcast: an rhs with a
 * single row is reused for every row, and an rhs with a single column is
 * reused for every column.
 *
 * If `res` is a null pointer, a new dense matrix of the lhs shape is
 * created. Every result cell is overwritten.
 *
 * @throws std::runtime_error if the rhs shape is neither equal to the lhs
 * shape nor broadcastable, or if the op code is not ADD.
 */
template<typename VT>
struct EwBinaryMat<DenseMatrix<VT>, DenseMatrix<VT>, CSRMatrix<VT>> {
    static void apply(BinaryOpCode opCode, DenseMatrix<VT> *& res, const DenseMatrix<VT> * lhs, const CSRMatrix<VT> * rhs, DCTX(ctx)) {
        const size_t numRows = lhs->getNumRows();
        const size_t numCols = lhs->getNumCols();

        if((numRows != rhs->getNumRows() && rhs->getNumRows() != 1) || (numCols != rhs->getNumCols() && rhs->getNumCols() != 1))
            throw std::runtime_error("EwBinaryMat(Dense) - lhs and rhs must have the same dimensions (or broadcast)");

        if(res == nullptr)
            res = DataObjectFactory::create<DenseMatrix<VT>>(numRows, numCols, false);

        VT * valuesRes = res->getValues();

        EwBinaryScaFuncPtr<VT, VT, VT> func = getEwBinaryScaFuncPtr<VT, VT, VT>(opCode);

        switch(opCode) {
            case BinaryOpCode::ADD: {
                // The iteration space is the lhs shape, so lhs is indexed
                // directly. It is the rhs that the shape check allows to be
                // one row and/or one column, so the rhs coordinates are the
                // ones that must be clamped to 0 when broadcasting.
                const bool broadcastRow = (rhs->getNumRows() == 1);
                const bool broadcastCol = (rhs->getNumCols() == 1);
                for(size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
                    const size_t rhsRow = broadcastRow ? 0 : rowIdx;
                    for(size_t colIdx = 0; colIdx < numCols; colIdx++) {
                        const size_t rhsCol = broadcastCol ? 0 : colIdx;
                        auto lhsVal = lhs->get(rowIdx, colIdx);
                        auto rhsVal = rhs->get(rhsRow, rhsCol);
                        valuesRes[rowIdx * numCols + colIdx] = func(lhsVal, rhsVal, ctx);
                    }
                }
                break;
            }
            default:
                throw std::runtime_error("EwBinaryMat(Dense) - unsupported BinaryOpCode");
        }
    }
};

// ----------------------------------------------------------------------------
// CSRMatrix <- CSRMatrix, CSRMatrix
// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -269,6 +412,9 @@ struct EwBinaryMat<CSRMatrix<VT>, CSRMatrix<VT>, DenseMatrix<VT>> {
case BinaryOpCode::MUL: // intersect
maxNnz = lhs->getNumNonZeros();
break;
case BinaryOpCode::ADD:
maxNnz = lhs->getNumNonZeros() + rhs->getNumNonZeros();
break;
default:
throw std::runtime_error("EwBinaryMat(CSR) - unknown BinaryOpCode");
}
Expand All @@ -277,43 +423,81 @@ struct EwBinaryMat<CSRMatrix<VT>, CSRMatrix<VT>, DenseMatrix<VT>> {
res = DataObjectFactory::create<CSRMatrix<VT>>(numRows, numCols, maxNnz, false);

size_t *rowOffsetsRes = res->getRowOffsets();
rowOffsetsRes[0] = 0;

EwBinaryScaFuncPtr<VT, VT, VT> func = getEwBinaryScaFuncPtr<VT, VT, VT>(opCode);

rowOffsetsRes[0] = 0;

switch(opCode) {
case BinaryOpCode::MUL: { // intersect non-zero cells
for(size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
size_t nnzRowLhs = lhs->getNumNonZeros(rowIdx);
if(nnzRowLhs) {
// intersect within row
const VT * valuesRowLhs = lhs->getValues(rowIdx);
VT * valuesRowRes = res->getValues(rowIdx);
const size_t * colIdxsRowLhs = lhs->getColIdxs(rowIdx);
size_t * colIdxsRowRes = res->getColIdxs(rowIdx);
auto rhsRow = (rhs->getNumRows() == 1 ? 0 : rowIdx);
size_t posRes = 0;
for (size_t posLhs = 0; posLhs < nnzRowLhs; ++posLhs) {
auto rhsCol = (rhs->getNumCols() == 1 ? 0 : colIdxsRowLhs[posLhs]);
auto rVal = rhs->get(rhsRow, rhsCol);
if(rVal != 0) {
valuesRowRes[posRes] = func(valuesRowLhs[posLhs], rVal, ctx);
colIdxsRowRes[posRes] = colIdxsRowLhs[posLhs];
case BinaryOpCode::MUL: { // intersect non-zero cells
for(size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
size_t nnzRowLhs = lhs->getNumNonZeros(rowIdx);
if(nnzRowLhs) {
// intersect within row
const VT * valuesRowLhs = lhs->getValues(rowIdx);
VT * valuesRowRes = res->getValues(rowIdx);
const size_t * colIdxsRowLhs = lhs->getColIdxs(rowIdx);
size_t * colIdxsRowRes = res->getColIdxs(rowIdx);
auto rhsRow = (rhs->getNumRows() == 1 ? 0 : rowIdx);
size_t posRes = 0;
for (size_t posLhs = 0; posLhs < nnzRowLhs; ++posLhs) {
auto rhsCol = (rhs->getNumCols() == 1 ? 0 : colIdxsRowLhs[posLhs]);
auto rVal = rhs->get(rhsRow, rhsCol);
if(rVal != 0) {
valuesRowRes[posRes] = func(valuesRowLhs[posLhs], rVal, ctx);
colIdxsRowRes[posRes] = colIdxsRowLhs[posLhs];
posRes++;
}
}
rowOffsetsRes[rowIdx + 1] = rowOffsetsRes[rowIdx] + posRes;
}
else
// empty row in result
rowOffsetsRes[rowIdx + 1] = rowOffsetsRes[rowIdx];
}
break;
}
case BinaryOpCode::ADD: {
VT* valuesRes = res->getValues(0); // Pointer to the result values array
size_t* colIdxsRes = res->getColIdxs(0); // Pointer to the result column indices array
size_t posRes = 0; // Track position in the result matrix's values and colIdxs arrays

for (size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
const size_t rhsRow = (rhs->getNumRows() == 1 ? 0 : rowIdx);
size_t posLhs = 0; // Current position in LHS row
size_t nnzRowLhs = lhs->getNumNonZeros(rowIdx);
const VT* valuesRowLhs = lhs->getValues(rowIdx);
const size_t* colIdxsRowLhs = lhs->getColIdxs(rowIdx);

size_t colIdxRhs = 0; // Start from the first column index for the dense matrix

// Merge LHS and RHS
while (colIdxRhs < numCols) {
VT lhsVal = 0;
VT rhsVal = rhs->get(rhsRow, colIdxRhs);

// Check if there's a corresponding LHS value
if (posLhs < nnzRowLhs && colIdxsRowLhs[posLhs] == colIdxRhs) {
lhsVal = valuesRowLhs[posLhs++];
}

VT resultVal = func(lhsVal, rhsVal, ctx);

// Only store non-zero results
if (resultVal != 0) {
valuesRes[posRes] = resultVal;
colIdxsRes[posRes] = colIdxRhs;
posRes++;
}
colIdxRhs++;
}
rowOffsetsRes[rowIdx + 1] = rowOffsetsRes[rowIdx] + posRes;
rowOffsetsRes[rowIdx + 1] = posRes;
}
else
// empty row in result
rowOffsetsRes[rowIdx + 1] = rowOffsetsRes[rowIdx];
break;
}

default:
throw std::runtime_error("EwBinaryMat(CSR) - unknown BinaryOpCode");
}
break;
}
default:
throw std::runtime_error("EwBinaryMat(CSR) - unknown BinaryOpCode");
}

// TODO Update number of non-zeros in result in the end.
}
Expand Down
Loading