From 5a7d7ce1a5fc4aef73cd717b04a6e8edb5646200 Mon Sep 17 00:00:00 2001 From: Seah Kim Date: Fri, 13 Oct 2023 15:28:35 -0700 Subject: [PATCH] added zero gating --- src/main/scala/gemmini/Mesh.scala | 111 +++++++++++++++++++++++++++++- src/main/scala/gemmini/PE.scala | 68 ++++++++++++++++-- src/main/scala/gemmini/Tile.scala | 43 ++++++++++++ 3 files changed, 213 insertions(+), 9 deletions(-) diff --git a/src/main/scala/gemmini/Mesh.scala b/src/main/scala/gemmini/Mesh.scala index cd056658..bfab3e34 100644 --- a/src/main/scala/gemmini/Mesh.scala +++ b/src/main/scala/gemmini/Mesh.scala @@ -35,6 +35,23 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, val out_last = Output(Vec(meshColumns, Vec(tileColumns, Bool()))) }) + // for zero-gating + val in_a_zero = Wire(Vec(meshRows, Vec(tileRows, Bool()))) + val in_b_zero = Wire(Vec(meshColumns, Vec(tileColumns, Bool()))) + val in_d_zero = Wire(Vec(meshColumns, Vec(tileColumns, Bool()))) + + for (i <- 0 until meshRows; j <- 0 until tileRows) { + in_a_zero(i)(j) := io.in_a(i)(j).asUInt === 0.U + } + + for (i <- 0 until meshColumns; j <- 0 until tileColumns) { + in_b_zero(i)(j) := io.in_b(i)(j).asUInt === 0.U + } + + for (i <- 0 until meshColumns; j <- 0 until tileColumns) { + in_d_zero(i)(j) := io.in_d(i)(j).asUInt === 0.U + } + // mesh(r)(c) => Tile at row r, column c val mesh: Seq[Seq[Tile[T]]] = Seq.fill(meshRows, meshColumns)(Module(new Tile(inputType, outputType, accType, df, tree_reduction, max_simultaneous_matmuls, tileRows, tileColumns))) val meshT = mesh.transpose @@ -47,6 +64,7 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, // Chain tile_a_out -> tile_a_in (pipeline a across each row) // TODO clock-gate A signals with in_garbage +/* for (r <- 0 until meshRows) { mesh(r).foldLeft(io.in_a(r)) { case (in_a, tile) => @@ -54,8 +72,24 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, tile.io.out_a } } +*/ + // added for zero-gating + for (r <- 0 until meshRows) { + mesh(r).foldLeft((io.in_a(r), in_a_zero(r))) { + case ((in_a, az), tile) => + // tile.io.in_a := RegNext(in_a) + + (tile.io.in_a, in_a, az).zipped.foreach { case (ina, a, z) => + //ina := RegEnable(a, !z) + ina := ShiftRegister(a, tile_latency+1, !z) + } + + (tile.io.out_a, tile.io.out_a_zero) + } + } - // Chain tile_out_b -> tile_b_in (pipeline b across each column) + // Chain tile_out_b -> tile_b_in (pipeline b across each column) + /* for (c <- 0 until meshColumns) { meshT(c).foldLeft((io.in_b(c), io.in_valid(c))) { case ((in_b, valid), tile) => @@ -63,8 +97,24 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, (tile.io.out_b, tile.io.out_valid) } } + */ + // for zero-gating + for(c <- 0 until meshColumns){ + meshT(c).foldLeft((io.in_b(c), io.in_valid(c), in_b_zero(c))) { + case ((in_b, valid, bz), tile) => + // tile.io.in_b := RegEnable(in_b, valid.head) + + (tile.io.in_b, in_b, bz).zipped.foreach { case (inb, b, z) => + //inb := RegEnable(b, valid.head && !z) + inb := pipe(valid.head && !z, b, tile_latency + 1) + } + + (tile.io.out_b, tile.io.out_valid, tile.io.out_b_zero) + } + } // Chain tile_out -> tile_propag (pipeline output across each column) + /* for (c <- 0 until meshColumns) { meshT(c).foldLeft((io.in_d(c), io.in_valid(c))) { case ((in_propag, valid), tile) => @@ -73,6 +123,24 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, } } + */ + + // for zero-gating + for(c <- 0 until meshColumns){ + meshT(c).foldLeft((io.in_d(c), io.in_valid(c), in_d_zero(c))) { + case ((in_propag, valid, outz), tile) => + // tile.io.in_d := RegEnable(in_propag, valid.head) + + (tile.io.in_d, in_propag, outz).zipped.foreach { case (ind, prop, z) => + //ind := RegEnable(prop, valid.head && !z) + ind := pipe(valid.head && !z, prop, tile_latency + 1) + } + + (tile.io.out_c, tile.io.out_valid, tile.io.out_c_zero) + } + } + + // Chain control signals (pipeline across each column) assert(!(mesh.map(_.map(_.io.bad_dataflow).reduce(_||_)).reduce(_||_))) for (c <- 0 until meshColumns) { @@ -114,13 +182,50 @@ class Mesh[T <: Data : Arithmetic](inputType: T, outputType: T, accType: T, } } + // added for zero-gating + // Chain a_zero (pipeline across each column) + for (r <- 0 until meshRows) { + mesh(r).foldLeft(in_a_zero(r)) { + case (in_z, tile) => + tile.io.in_a_zero := ShiftRegister(in_z, tile_latency+1) + tile.io.out_a_zero + } + } + + // Chain b_zero (pipeline across each column) + for (c <- 0 until meshColumns) { + meshT(c).foldLeft(in_b_zero(c)) { + case (in_z, tile) => + tile.io.in_b_zero := ShiftRegister(in_z, tile_latency+1) + tile.io.out_b_zero + } + } + + // Chain d_zero (pipeline across each column) + for (c <- 0 until meshColumns) { + meshT(c).foldLeft(in_d_zero(c)) { + case (in_z, tile) => + tile.io.in_d_zero := ShiftRegister(in_z, tile_latency+1) + tile.io.out_c_zero + } + } + // Capture out_vec and out_control_vec (connect IO to bottom row of mesh) // (The only reason we have so many zips is because Scala doesn't provide a zipped function for Tuple4) for (((((((b, c), v), ctrl), id), last), tile) <- io.out_b zip io.out_c zip io.out_valid zip io.out_control zip io.out_id zip io.out_last zip mesh.last) { // TODO we pipelined this to make physical design easier. Consider removing these if possible // TODO shouldn't we clock-gate these signals with "garbage" as well? - b := ShiftRegister(tile.io.out_b, output_delay) - c := ShiftRegister(tile.io.out_c, output_delay) + //b := ShiftRegister(tile.io.out_b, output_delay) + //c := ShiftRegister(tile.io.out_c, output_delay) + + //added for zero-gating + b := ShiftRegister(VecInit(tile.io.out_b.zip(tile.io.out_b_zero).map { case (outb, outbz) => + Mux(outbz, 0.U.asTypeOf(outb), outb) + }), output_delay) + c := ShiftRegister(VecInit(tile.io.out_c.zip(tile.io.out_c_zero).map { case (outc, outcz) => + Mux(outcz, 0.U.asTypeOf(outc), outc) + }), output_delay) + v := ShiftRegister(tile.io.out_valid, output_delay) ctrl := ShiftRegister(tile.io.out_control, output_delay) id := ShiftRegister(tile.io.out_id, output_delay) diff --git a/src/main/scala/gemmini/PE.scala b/src/main/scala/gemmini/PE.scala index 6e065125..0804285a 100644 --- a/src/main/scala/gemmini/PE.scala +++ b/src/main/scala/gemmini/PE.scala @@ -53,6 +53,15 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val out_valid = Output(Bool()) val bad_dataflow = Output(Bool()) + + // added for zero-gating + + val in_a_zero = Input(Bool()) + val in_b_zero = Input(Bool()) + val in_d_zero = Input(Bool()) + val out_a_zero = Output(Bool()) + val out_b_zero = Output(Bool()) + val out_c_zero = Output(Bool()) }) val cType = if (df == Dataflow.WS) inputType else accType @@ -76,6 +85,14 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, val last = io.in_last val valid = io.in_valid + // added for zero-gating + val a_zero = io.in_a_zero + val b_zero = io.in_b_zero + val d_zero = io.in_d_zero + val c1_zero = Reg(Bool()) + val c2_zero = Reg(Bool()) + + io.out_a := a io.out_control.dataflow := dataflow io.out_control.propagate := prop @@ -84,6 +101,10 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, io.out_last := last io.out_valid := valid + // for zero-gating + io.out_a_zero := a_zero + io.out_c_zero := io.out_c.asUInt === 0.U + mac_unit.io.in_a := a val last_s = RegEnable(prop, valid) @@ -115,25 +136,45 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, c1 := mac_unit.io.out_d c2 := d.withWidthOf(cType) } + io.out_b_zero := b_zero }.elsewhen ((df == Dataflow.WS).B || ((df == Dataflow.BOTH).B && dataflow === WEIGHT_STATIONARY)) { when(prop === PROPAGATE) { io.out_c := c1 mac_unit.io.in_b := c2.asTypeOf(inputType) - mac_unit.io.in_c := b - io.out_b := mac_unit.io.out_d - c1 := d + //mac_unit.io.in_c := b + //io.out_b := mac_unit.io.out_d + //c1 := d + //added for zero-gating + //io.out_b := Mux(c2_zero, b, Mux(b_zero, 0.U.asTypeOf(b), b).mac(Mux(a_zero, 0.U.asTypeOf(a), a), c2.asTypeOf(inputType))) + mac_unit.io.in_a := Mux(a_zero, 0.U.asTypeOf(a), a) + mac_unit.io.in_c := Mux(b_zero, 0.U.asTypeOf(b), b) + io.out_b := Mux(c2_zero, b, mac_unit.io.out_d) + c1 := Mux(d_zero, 0.U.asTypeOf(d), d) + c1_zero := d_zero + io.out_b_zero := Mux(c2_zero, b_zero, io.out_b.asUInt === 0.U) }.otherwise { io.out_c := c2 mac_unit.io.in_b := c1.asTypeOf(inputType) - mac_unit.io.in_c := b - io.out_b := mac_unit.io.out_d - c2 := d + //mac_unit.io.in_c := b + //io.out_b := mac_unit.io.out_d + //c2 := d + //added for zero-gating + //io.out_b := Mux(c1_zero, b, Mux(b_zero, 0.U.asTypeOf(b), b).mac(Mux(a_zero, 0.U.asTypeOf(a), a), c1.asTypeOf(inputType))) + mac_unit.io.in_a := Mux(a_zero, 0.U.asTypeOf(a), a) + mac_unit.io.in_c := Mux(b_zero, 0.U.asTypeOf(b), b) + io.out_b := Mux(c1_zero, b, mac_unit.io.out_d) + c2 := Mux(d_zero, 0.U.asTypeOf(d), d) + c2_zero := d_zero + io.out_b_zero := Mux(c1_zero, b_zero, io.out_b.asUInt === 0.U) } }.otherwise { io.bad_dataflow := true.B //assert(false.B, "unknown dataflow") io.out_c := DontCare io.out_b := DontCare + // add for zero-gating + io.out_b_zero := DontCare + io.out_c_zero := DontCare mac_unit.io.in_b := b.asTypeOf(inputType) mac_unit.io.in_c := c2 } @@ -143,5 +184,20 @@ class PE[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Value, c2 := c2 mac_unit.io.in_b := DontCare mac_unit.io.in_c := DontCare + //addd for zero-gating + c1_zero := c1_zero + c2_zero := c2_zero } + /* + when (io.in_a_zero || io.in_b_zero) { + c1 := c1 + c2 := c2 + mac_unit.io.in_b := DontCare + mac_unit.io.in_c := DontCare + //addd for zero-gating + //c1_zero := c1_zero + //c2_zero := c2_zero + } + + */ } diff --git a/src/main/scala/gemmini/Tile.scala b/src/main/scala/gemmini/Tile.scala index 9c2a418c..54f537a7 100644 --- a/src/main/scala/gemmini/Tile.scala +++ b/src/main/scala/gemmini/Tile.scala @@ -35,6 +35,14 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu val out_valid = Output(Vec(columns, Bool())) val bad_dataflow = Output(Bool()) + + // added for zero-gating + val in_a_zero = Input(Vec(rows, Bool())) + val in_b_zero = Input(Vec(columns, Bool())) + val in_d_zero = Input(Vec(columns, Bool())) + val out_a_zero = Output(Vec(rows, Bool())) + val out_b_zero = Output(Vec(columns, Bool())) + val out_c_zero = Output(Vec(columns, Bool())) }) import ev._ @@ -106,6 +114,35 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu } } + // added for zero-gating + // Broadcast 'a_zero' horizontally across the Tile + for (r <- 0 until columns) { + tile(r).foldLeft(io.in_a_zero(r)) { + case (z, pe) => + pe.io.in_a_zero := z + pe.io.out_a_zero + } + } + + // Broadcast 'b_zero' vertically across the Tile + for (c <- 0 until columns) { + tileT(c).foldLeft(io.in_b_zero(c)) { + case (z, pe) => + pe.io.in_b_zero := z + pe.io.out_b_zero + } + } + + // Broadcast 'd_zero' vertically across the Tile + for (c <- 0 until columns) { + tileT(c).foldLeft(io.in_d_zero(c)) { + case (z, pe) => + pe.io.in_d_zero := z + pe.io.out_c_zero + } + } + + // Drive the Tile's bottom IO for (c <- 0 until columns) { io.out_c(c) := tile(rows-1)(c).io.out_c @@ -114,6 +151,10 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu io.out_last(c) := tile(rows-1)(c).io.out_last io.out_valid(c) := tile(rows-1)(c).io.out_valid + //added for zero-gating + io.out_b_zero(c) := tile(rows-1)(c).io.out_b_zero + io.out_c_zero(c) := tile(rows-1)(c).io.out_c_zero + io.out_b(c) := { if (tree_reduction) { val prods = tileT(c).map(_.io.out_b) @@ -128,5 +169,7 @@ class Tile[T <: Data](inputType: T, outputType: T, accType: T, df: Dataflow.Valu // Drive the Tile's right IO for (r <- 0 until rows) { io.out_a(r) := tile(r)(columns-1).io.out_a + //added for zero-gating + io.out_a_zero(r) := tile(r)(columns-1).io.out_a_zero } }