Skip to content

Commit

Permalink
added zero gating
Browse files Browse the repository at this point in the history
  • Loading branch information
Seah Kim authored and Seah Kim committed Oct 13, 2023
1 parent 125e72d commit 5a7d7ce
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 9 deletions.
111 changes: 108 additions & 3 deletions src/main/scala/gemmini/Mesh.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T,
val out_last = Output(Vec(meshColumns, Vec(tileColumns, Bool())))
})

// Zero-gating support: one Bool per mesh input lane, asserted when the raw bits
// of the corresponding operand are all zero.
val in_a_zero = Wire(Vec(meshRows, Vec(tileRows, Bool())))
val in_b_zero = Wire(Vec(meshColumns, Vec(tileColumns, Bool())))
val in_d_zero = Wire(Vec(meshColumns, Vec(tileColumns, Bool())))

// A flags are indexed by (mesh row, lane within the tile row).
for {
  row <- 0 until meshRows
  lane <- 0 until tileRows
} in_a_zero(row)(lane) := io.in_a(row)(lane).asUInt === 0.U

// B and D flags are both indexed by (mesh column, lane within the tile column),
// so a single sweep drives the two of them.
for {
  col <- 0 until meshColumns
  lane <- 0 until tileColumns
} {
  in_b_zero(col)(lane) := io.in_b(col)(lane).asUInt === 0.U
  in_d_zero(col)(lane) := io.in_d(col)(lane).asUInt === 0.U
}

// mesh(r)(c) => Tile at row r, column c
val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, tree_reduction, max_simultaneous_matmuls, tileRows, tileColumns)))
val meshT = mesh.transpose
Expand All @@ -47,24 +64,57 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T,

// Chain tile_a_out -> tile_a_in (pipeline a across each row)
// TODO clock-gate A signals with in_garbage
/*
for (r <- 0 until meshRows) {
mesh(r).foldLeft(io.in_a(r)) {
case (in_a, tile) =>
tile.io.in_a := ShiftRegister(in_a, tile_latency+1)
tile.io.out_a
}
}
*/
// Zero-gated A pipeline: walk the tiles of every mesh row from left to right,
// carrying the A operands together with their per-lane zero flags.
(0 until meshRows).foreach { r =>
  mesh(r).foldLeft((io.in_a(r), in_a_zero(r))) {
    case ((aLanes, zeroLanes), tile) =>
      // Every lane has its own enable: the shift-register stages feeding the tile
      // only update while that lane's zero flag is low.
      tile.io.in_a.zip(aLanes).zip(zeroLanes).foreach { case ((dst, src), isZero) =>
        dst := ShiftRegister(src, tile_latency+1, !isZero)
      }

      // The tile's outputs become the inputs of its right-hand neighbour.
      (tile.io.out_a, tile.io.out_a_zero)
  }
}

// Chain tile_out_b -> tile_b_in (pipeline b across each column)
// Chain tile_out_b -> tile_b_in (pipeline b across each column)
/*
for (c <- 0 until meshColumns) {
meshT(c).foldLeft((io.in_b(c), io.in_valid(c))) {
case ((in_b, valid), tile) =>
tile.io.in_b := pipe(valid.head, in_b, tile_latency+1)
(tile.io.out_b, tile.io.out_valid)
}
}
*/
// Zero-gated B pipeline: walk the tiles of every mesh column from top to bottom,
// carrying the B operands, the valid bits and the per-lane zero flags.
(0 until meshColumns).foreach { c =>
  meshT(c).foldLeft((io.in_b(c), io.in_valid(c), in_b_zero(c))) {
    case ((bLanes, valid, zeroLanes), tile) =>
      // Per lane, `pipe` receives "valid and not zero" as its first argument
      // (presumably its enable -- confirm against the helper's definition).
      tile.io.in_b.zip(bLanes).zip(zeroLanes).foreach { case ((dst, src), isZero) =>
        dst := pipe(valid.head && !isZero, src, tile_latency + 1)
      }

      // The tile's outputs become the inputs of the tile below it.
      (tile.io.out_b, tile.io.out_valid, tile.io.out_b_zero)
  }
}

// Chain tile_out -> tile_propag (pipeline output across each column)
/*
for (c <- 0 until meshColumns) {
meshT(c).foldLeft((io.in_d(c), io.in_valid(c))) {
case ((in_propag, valid), tile) =>
Expand All @@ -73,6 +123,24 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T,
}
}
*/

// Zero-gated D/propagate pipeline: walk the tiles of every mesh column from top
// to bottom, carrying the propagated values, the valid bits and the zero flags.
(0 until meshColumns).foreach { c =>
  meshT(c).foldLeft((io.in_d(c), io.in_valid(c), in_d_zero(c))) {
    case ((propLanes, valid, zeroLanes), tile) =>
      // Per lane, `pipe` receives "valid and not zero" as its first argument
      // (presumably its enable -- confirm against the helper's definition).
      tile.io.in_d.zip(propLanes).zip(zeroLanes).foreach { case ((dst, src), isZero) =>
        dst := pipe(valid.head && !isZero, src, tile_latency + 1)
      }

      // out_c (with its zero flag) of this tile feeds in_d of the tile below it.
      (tile.io.out_c, tile.io.out_valid, tile.io.out_c_zero)
  }
}


// Chain control signals (pipeline across each column)
assert(!(mesh.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_)))
for (c <- 0 until meshColumns) {
Expand Down Expand Up @@ -114,13 +182,50 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T,
}
}

// added for zero-gating
// The zero flags travel through the mesh with the same (tile_latency+1)-cycle
// delay per tile that is applied to the operands they describe.

// Chain a_zero (pipeline across each row)
(0 until meshRows).foreach { r =>
  mesh(r).foldLeft(in_a_zero(r)) { (flags, tile) =>
    tile.io.in_a_zero := ShiftRegister(flags, tile_latency+1)
    tile.io.out_a_zero
  }
}

// Chain b_zero (pipeline across each column)
(0 until meshColumns).foreach { c =>
  meshT(c).foldLeft(in_b_zero(c)) { (flags, tile) =>
    tile.io.in_b_zero := ShiftRegister(flags, tile_latency+1)
    tile.io.out_b_zero
  }
}

// Chain d_zero (pipeline across each column); a tile's out_c_zero feeds the
// in_d_zero of the tile below it.
(0 until meshColumns).foreach { c =>
  meshT(c).foldLeft(in_d_zero(c)) { (flags, tile) =>
    tile.io.in_d_zero := ShiftRegister(flags, tile_latency+1)
    tile.io.out_c_zero
  }
}

// Capture out_vec and out_control_vec (connect IO to bottom row of mesh)
// (The only reason we have so many zips is because Scala doesn't provide a zipped function for Tuple4)
for (((((((b, c), v), ctrl), id), last), tile) <- io.out_b zip io.out_c zip io.out_valid zip io.out_control zip io.out_id zip io.out_last zip mesh.last) {
// TODO we pipelined this to make physical design easier. Consider removing these if possible
// TODO shouldn't we clock-gate these signals with "garbage" as well?
b := ShiftRegister(tile.io.out_b, output_delay)
c := ShiftRegister(tile.io.out_c, output_delay)
//b := ShiftRegister(tile.io.out_b, output_delay)
//c := ShiftRegister(tile.io.out_c, output_delay)

//added for zero-gating
b := ShiftRegister(VecInit(tile.io.out_b.zip(tile.io.out_b_zero).map { case (outb, outbz) =>
Mux(outbz, 0.U.asTypeOf(outb), outb)
}), output_delay)
c := ShiftRegister(VecInit(tile.io.out_c.zip(tile.io.out_c_zero).map { case (outc, outcz) =>
Mux(outcz, 0.U.asTypeOf(outc), outc)
}), output_delay)

v := ShiftRegister(tile.io.out_valid, output_delay)
ctrl := ShiftRegister(tile.io.out_control, output_delay)
id := ShiftRegister(tile.io.out_id, output_delay)
Expand Down
68 changes: 62 additions & 6 deletions src/main/scala/gemmini/PE.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
val out_valid = Output(Bool())

val bad_dataflow = Output(Bool())

// added for zero-gating
// Single-bit sideband flags that accompany the a/b/d operands and the b/c results.

// Flag for in_a; forwarded unchanged to out_a_zero.
val in_a_zero = Input(Bool())
// Flag for in_b; when set, the weight-stationary branches feed the MAC a zero in place of b.
val in_b_zero = Input(Bool())
// Flag for in_d; in the weight-stationary branches it is stored into c1_zero/c2_zero when c1/c2 are loaded.
val in_d_zero = Input(Bool())
// Copy of in_a_zero.
val out_a_zero = Output(Bool())
// Flag for out_b; driven per dataflow branch (DontCare on the bad-dataflow path).
val out_b_zero = Output(Bool())
// Asserted when the bits of out_c are all zero (DontCare on the bad-dataflow path).
val out_c_zero = Output(Bool())
})

val cType = if (df == Dataflow.WS) inputType else accType
Expand All @@ -76,6 +85,14 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
val last = io.in_last
val valid = io.in_valid

// added for zero-gating
// Local aliases for the incoming zero flags.
val a_zero = io.in_a_zero
val b_zero = io.in_b_zero
val d_zero = io.in_d_zero
// Zero flags kept alongside c1/c2: the weight-stationary branches write d_zero into
// them in the same branch that loads c1/c2, and they hold their value when !valid.
// NOTE(review): declared without a reset value, so they are undefined until first
// written -- confirm this is acceptable.
val c1_zero = Reg(Bool())
val c2_zero = Reg(Bool())


io.out_a := a
io.out_control.dataflow := dataflow
io.out_control.propagate := prop
Expand All @@ -84,6 +101,10 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
io.out_last := last
io.out_valid := valid

// for zero-gating
io.out_a_zero := a_zero
io.out_c_zero := io.out_c.asUInt === 0.U

mac_unit.io.in_a := a

val last_s = RegEnable(prop, valid)
Expand Down Expand Up @@ -115,25 +136,45 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
c1 := mac_unit.io.out_d
c2 := d.withWidthOf(cType)
}
io.out_b_zero := b_zero
}.elsewhen ((df == Dataflow.WS).B || ((df == Dataflow.BOTH).B && dataflow === WEIGHT_STATIONARY)) {
when(prop === PROPAGATE) {
io.out_c := c1
mac_unit.io.in_b := c2.asTypeOf(inputType)
mac_unit.io.in_c := b
io.out_b := mac_unit.io.out_d
c1 := d
//mac_unit.io.in_c := b
//io.out_b := mac_unit.io.out_d
//c1 := d
//added for zero-gating
//io.out_b := Mux(c2_zero, b, Mux(b_zero, 0.U.asTypeOf(b), b).mac(Mux(a_zero, 0.U.asTypeOf(a), a), c2.asTypeOf(inputType)))
mac_unit.io.in_a := Mux(a_zero, 0.U.asTypeOf(a), a)
mac_unit.io.in_c := Mux(b_zero, 0.U.asTypeOf(b), b)
io.out_b := Mux(c2_zero, b, mac_unit.io.out_d)
c1 := Mux(d_zero, 0.U.asTypeOf(d), d)
c1_zero := d_zero
io.out_b_zero := Mux(c2_zero, b_zero, io.out_b.asUInt === 0.U)
}.otherwise {
io.out_c := c2
mac_unit.io.in_b := c1.asTypeOf(inputType)
mac_unit.io.in_c := b
io.out_b := mac_unit.io.out_d
c2 := d
//mac_unit.io.in_c := b
//io.out_b := mac_unit.io.out_d
//c2 := d
//added for zero-gating
//io.out_b := Mux(c1_zero, b, Mux(b_zero, 0.U.asTypeOf(b), b).mac(Mux(a_zero, 0.U.asTypeOf(a), a), c1.asTypeOf(inputType)))
mac_unit.io.in_a := Mux(a_zero, 0.U.asTypeOf(a), a)
mac_unit.io.in_c := Mux(b_zero, 0.U.asTypeOf(b), b)
io.out_b := Mux(c1_zero, b, mac_unit.io.out_d)
c2 := Mux(d_zero, 0.U.asTypeOf(d), d)
c2_zero := d_zero
io.out_b_zero := Mux(c1_zero, b_zero, io.out_b.asUInt === 0.U)
}
}.otherwise {
io.bad_dataflow := true.B
//assert(false.B, "unknown dataflow")
io.out_c := DontCare
io.out_b := DontCare
// add for zero-gating
io.out_b_zero := DontCare
io.out_c_zero := DontCare
mac_unit.io.in_b := b.asTypeOf(inputType)
mac_unit.io.in_c := c2
}
Expand All @@ -143,5 +184,20 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value,
c2 := c2
mac_unit.io.in_b := DontCare
mac_unit.io.in_c := DontCare
// added for zero-gating
c1_zero := c1_zero
c2_zero := c2_zero
}
/*
when (io.in_a_zero || io.in_b_zero) {
c1 := c1
c2 := c2
mac_unit.io.in_b := DontCare
mac_unit.io.in_c := DontCare
// added for zero-gating
//c1_zero := c1_zero
//c2_zero := c2_zero
}
*/
}
43 changes: 43 additions & 0 deletions src/main/scala/gemmini/Tile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu
val out_valid = Output(Vec(columns, Bool()))

val bad_dataflow = Output(Bool())

// added for zero-gating
// Per-lane zero flags for the Tile's operands and results.
// One flag per tile row; entry r is fed into the first PE of tile(r).
val in_a_zero = Input(Vec(rows, Bool()))
// One flag per tile column; entry c is fed into the first PE of column c.
val in_b_zero = Input(Vec(columns, Bool()))
// One flag per tile column; entry c is fed into the first PE of column c.
val in_d_zero = Input(Vec(columns, Bool()))
// Entry r is out_a_zero of the last PE in row r.
val out_a_zero = Output(Vec(rows, Bool()))
// Entry c is out_b_zero of the bottom PE in column c.
val out_b_zero = Output(Vec(columns, Bool()))
// Entry c is out_c_zero of the bottom PE in column c.
val out_c_zero = Output(Vec(columns, Bool()))
})

import ev._
Expand Down Expand Up @@ -106,6 +114,35 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu
}
}

// added for zero-gating
// Broadcast 'a_zero' horizontally across the Tile.
// io.in_a_zero holds one flag per row and `tile` is indexed as tile(row)(column),
// so the loop must range over rows. Ranging over columns would index out of bounds
// when columns > rows and leave PEs in the remaining rows undriven when columns < rows.
for (r <- 0 until rows) {
  tile(r).foldLeft(io.in_a_zero(r)) {
    case (z, pe) =>
      // Each PE receives the flag from its left neighbour and hands on its own out_a_zero.
      pe.io.in_a_zero := z
      pe.io.out_a_zero
  }
}

// Broadcast 'b_zero' and 'd_zero' vertically across the Tile.
// Both flags run down the same column of PEs, so each column is visited once.
(0 until columns).foreach { c =>
  val column = tileT(c)

  // b_zero: every PE takes the flag from the PE above and passes on its out_b_zero.
  column.foldLeft(io.in_b_zero(c)) { (flag, pe) =>
    pe.io.in_b_zero := flag
    pe.io.out_b_zero
  }

  // d_zero: every PE takes the flag from the PE above and passes on its out_c_zero.
  column.foldLeft(io.in_d_zero(c)) { (flag, pe) =>
    pe.io.in_d_zero := flag
    pe.io.out_c_zero
  }
}


// Drive the Tile's bottom IO
for (c <- 0 until columns) {
io.out_c(c) := tile(rows-1)(c).io.out_c
Expand All @@ -114,6 +151,10 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu
io.out_last(c) := tile(rows-1)(c).io.out_last
io.out_valid(c) := tile(rows-1)(c).io.out_valid

//added for zero-gating
io.out_b_zero(c) := tile(rows-1)(c).io.out_b_zero
io.out_c_zero(c) := tile(rows-1)(c).io.out_c_zero

io.out_b(c) := {
if (tree_reduction) {
val prods = tileT(c).map(_.io.out_b)
Expand All @@ -128,5 +169,7 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu
// Drive the Tile's right IO from the last PE of every row:
// the A operand and (added for zero-gating) its zero flag.
(0 until rows).foreach { r =>
  val rightmost = tile(r)(columns-1)
  io.out_a(r) := rightmost.io.out_a
  io.out_a_zero(r) := rightmost.io.out_a_zero
}
}

0 comments on commit 5a7d7ce

Please sign in to comment.