single rw tilelink node for external scratchpad
richardyrh committed Jan 29, 2024
1 parent 81aee08 commit abfdff4
Showing 2 changed files with 160 additions and 39 deletions.
165 changes: 143 additions & 22 deletions src/main/scala/gemmini/Controller.scala
@@ -8,8 +8,8 @@ import chisel3.util._
import org.chipsalliance.cde.config._
import freechips.rocketchip.diplomacy._
import freechips.rocketchip.tile._
import freechips.rocketchip.util.ClockGate
import freechips.rocketchip.tilelink.{TLBundle, TLClientNode, TLEdgeOut, TLFragmenter, TLIdentityNode, TLManagerNode, TLMasterParameters, TLMasterPortParameters, TLMasterToSlaveTransferSizes, TLRAM, TLRegisterNode, TLSlaveParameters, TLSlavePortParameters, TLWidthWidget, TLXbar}
import freechips.rocketchip.util.{BundleField, ClockGate}
import freechips.rocketchip.tilelink._
import GemminiISA._
import Util._
import freechips.rocketchip.regmapper.RegField
@@ -63,6 +63,67 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
val acc_data_len = config.sp_width / config.inputType.getWidth * config.accType.getWidth / 8
val max_data_len = spad_data_len // max acc_data_len

val mem_depth = config.sp_bank_entries * spad_data_len / max_data_len
val mem_width = max_data_len

// this node accepts both read and write requests,
// splits & arbitrates them into one client node per type of operation
val unified_mem_node = TLNexusNode(
clientFn = { seq =>
val in_mapping = TLXbar.mapInputIds(seq)
val read_src_range = IdRange(in_mapping.map(_.start).min, in_mapping.map(_.end).max)
assert((read_src_range.start == 0) && isPow2(read_src_range.end))
val write_src_range = read_src_range.shift(read_src_range.size)

seq(0).v1copy(
echoFields = BundleField.union(seq.flatMap(_.echoFields)),
requestFields = BundleField.union(seq.flatMap(_.requestFields)),
responseKeys = seq.flatMap(_.responseKeys).distinct,
minLatency = seq.map(_.minLatency).min,
clients = Seq(
TLMasterParameters.v1(
name = "unified_mem_read_client",
sourceId = read_src_range,
supportsProbe = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsGet = TransferSizes.mincover(seq.map(_.anyEmitClaims.get)),
supportsPutFull = TransferSizes.none,
supportsPutPartial = TransferSizes.none
),
TLMasterParameters.v1(
name = "unified_mem_write_client",
sourceId = write_src_range,
supportsProbe = TransferSizes.mincover(
seq.map(_.anyEmitClaims.putFull) ++ seq.map(_.anyEmitClaims.putPartial)),
supportsGet = TransferSizes.none,
supportsPutFull = TransferSizes.mincover(seq.map(_.anyEmitClaims.putFull)),
supportsPutPartial = TransferSizes.mincover(seq.map(_.anyEmitClaims.putPartial))
)
)
)
},
managerFn = { seq =>
// val fifoIdFactory = TLXbar.relabeler()
seq(0).v1copy(
responseFields = BundleField.union(seq.flatMap(_.responseFields)),
requestKeys = seq.flatMap(_.requestKeys).distinct,
minLatency = seq.map(_.minLatency).min,
endSinkId = TLXbar.mapOutputIds(seq).map(_.end).max,
managers = Seq(TLSlaveParameters.v2(
name = Some(f"unified_mem_manager"),
address = Seq(AddressSet(spad_base, mem_depth * mem_width * config.sp_banks - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, mem_width),
putFull = TransferSizes(1, mem_width),
putPartial = TransferSizes(1, mem_width)),
fifoId = Some(0)
))
)
}
)
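
A sketch (not part of the commit, sizes assumed) of the source-ID layout the clientFn above builds: inbound reads reuse the xbar-assigned ID range directly, while writes use the same range shifted by its own power-of-two size, so a single extra source bit distinguishes the two outgoing clients.

import freechips.rocketchip.diplomacy.IdRange

// Sketch only: illustrates the read/write source-ID split with made-up sizes.
object SourceIdLayoutSketch extends App {
  val readSrcRange  = IdRange(0, 8)                          // e.g. 8 inbound source IDs (power of two)
  val writeSrcRange = readSrcRange.shift(readSrcRange.size)  // same width, offset to [8, 16)
  // A read keeps its inbound ID; a write gains the high bit, so a response can be
  // routed back through the right outgoing client by a simple range check on its source.
  println(s"reads use $readSrcRange, writes use $writeSrcRange")
}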

unified_mem_read_node := TLWidthWidget(spad_data_len) := unified_mem_node
unified_mem_write_node := TLWidthWidget(spad_data_len) := unified_mem_node

val spad_tl_ram : Seq[Seq[TLManagerNode]] = if (config.use_shared_ext_mem && config.use_tl_ext_mem) {
unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* spad_read_nodes
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
@@ -87,9 +148,6 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
}.map(x => Seq(x.node))
} else {
(0 until config.sp_banks).map { bank =>
val mem_depth = config.sp_bank_entries * spad_data_len / max_data_len
val mem_width = max_data_len

Seq(TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_read_mgr"),
@@ -116,22 +174,15 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
}
}

require(!config.sp_singleported)
if (config.sp_singleported) {
val xbar = TLXbar()
xbar :=* unified_mem_read_node
xbar :=* unified_mem_write_node
banks.foreach(_.head := xbar)
} else {
val r_xbar = TLXbar()
val w_xbar = TLXbar()
r_xbar :=* unified_mem_read_node
w_xbar :=* unified_mem_write_node
banks.foreach { mem =>
require(mem.length == 2)
mem.head := r_xbar
mem.last := TLFragmenter(spad_data_len, spad.maxBytes) := w_xbar
}
require(!config.sp_singleported, "external scratchpad must be dual ported")
val r_xbar = TLXbar()
val w_xbar = TLXbar()
r_xbar :=* unified_mem_read_node
w_xbar :=* unified_mem_write_node
banks.foreach { mem =>
require(mem.length == 2)
mem.head := r_xbar
mem.last := TLFragmenter(spad_data_len, spad.maxBytes) := w_xbar
}

banks
@@ -152,8 +203,8 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA

regNode := TLFragmenter(8, 64) := stlNode


unified_mem_write_node := spad.spad_writer.node

}

class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
@@ -276,6 +327,76 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
// connect(ext_mem_acc(i)(0), log2Up(outer.acc_data_len),
// r_node, r_edge, source_counters(2), w_node, w_edge, source_counters(3))
// }

// hook up read/write for general unified mem nodes
{
val u_out = outer.unified_mem_node.out
val u_in = outer.unified_mem_node.in
assert(u_out.length == 2)
println(f"gemmini unified memory node has ${u_in.length} incoming client(s)")

val r_out = u_out.head
val w_out = u_out.last

val in_src = TLXbar.mapInputIds(u_in.map(_._2.client))
val in_src_size = in_src.map(_.end).max
assert(isPow2(in_src_size)) // should be checked already, but just to be sure

// arbitrate all reads into one read while assigning source prefix, same for write
val a_arbiter_in = (u_in zip in_src).map { case ((in_node, _), src_range) =>
val in_r: DecoupledIO[TLBundleA] =
WireDefault(0.U.asTypeOf(Decoupled(new TLBundleA(in_node.a.bits.params.copy(
sourceBits = log2Up(in_src_size) + 1
)))))
val in_w: DecoupledIO[TLBundleA] = WireDefault(0.U.asTypeOf(in_r.cloneType))

val req_is_read = in_node.a.bits.opcode === TLMessages.Get

(Seq(in_r.bits.user, in_r.bits.opcode, in_r.bits.size, in_r.bits.data) zip
Seq(in_node.a.bits.user, in_node.a.bits.opcode, in_node.a.bits.size, in_node.a.bits.data))
.foreach { case (x, y) => x := y }
in_r.bits.source := in_node.a.bits.source | src_range.start.U | Mux(req_is_read, 0.U, in_src_size.U)
in_w.bits := in_r.bits

in_r.valid := in_node.a.valid && req_is_read
in_w.valid := in_node.a.valid && !req_is_read
in_node.a.ready := in_r.ready && in_w.ready // TODO(richard): could be a mux

(in_r, in_w)
}
// we cannot use round robin because it might reorder requests, even from the same client
val (a_arbiter_in_r_nodes, a_arbiter_in_w_nodes) = a_arbiter_in.unzip
TLArbiter.lowest(r_out._2, r_out._1.a, a_arbiter_in_r_nodes:_*)
TLArbiter.lowest(w_out._2, w_out._1.a, a_arbiter_in_w_nodes:_*)
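
The demux inside the map above reduces to the following standalone Chisel sketch (hypothetical request type, not the TileLink A-channel bundle): one Decoupled stream is split into a read stream and a write stream, and, as in the commit, a beat is only accepted when both outputs are ready.

import chisel3._
import chisel3.util._

// Sketch only: mirrors the read/write demux and handshake used for the A channel above.
class ReqSketch extends Bundle { val isRead = Bool(); val data = UInt(32.W) }

class ReadWriteSplitSketch extends Module {
  val io = IO(new Bundle {
    val in = Flipped(Decoupled(new ReqSketch))
    val rd = Decoupled(UInt(32.W))
    val wr = Decoupled(UInt(32.W))
  })
  io.rd.bits  := io.in.bits.data
  io.wr.bits  := io.in.bits.data
  io.rd.valid := io.in.valid && io.in.bits.isRead
  io.wr.valid := io.in.valid && !io.in.bits.isRead
  // Conservative handshake, as in the diff: accept only when both sides can take the beat
  // (the TODO in the commit notes this could be a mux on isRead instead).
  io.in.ready := io.rd.ready && io.wr.ready
}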

// for each unified mem node client, arbitrate read/write responses on d channel
(u_in zip in_src).zipWithIndex.foreach { case (((in_node, in_edge), src_range), i) =>
// assign d channel back based on source, invalid if source prefix mismatch
val resp = Seq(r_out._1.d, w_out._1.d)
val source_match = resp.map(r => src_range.contains(r.bits.source))
val d_arbiter_in = resp.map(r => WireDefault(
0.U.asTypeOf(Decoupled(new TLBundleD(r.bits.params.copy(
sourceBits = in_node.d.bits.source.getWidth,
sizeBits = in_node.d.bits.size.getWidth
))))
))
def trim(id: UInt, size: Int): UInt = if (size <= 1) 0.U else id(log2Ceil(size)-1, 0) // from Xbar

(d_arbiter_in lazyZip resp lazyZip source_match).foreach { case (arb_in, r, sm) =>
(Seq(arb_in.bits.user, arb_in.bits.opcode, arb_in.bits.data) zip
Seq(r.bits.user, r.bits.opcode, r.bits.data)).foreach { case (x, y) => x := y }
arb_in.bits.source := trim(r.bits.source, in_node.d.bits.source.getWidth) // we can trim b/c isPow2(prefix)
arb_in.bits.size := trim(r.bits.size, in_node.d.bits.size.getWidth) // FIXME: check truncation

arb_in.valid := r.valid && sm
r.ready := arb_in.ready
}

TLArbiter.robin(in_edge, in_node.d, d_arbiter_in:_*)
}

}
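
Taken together, the source-ID round trip above can be modeled with plain integers (a sketch with assumed sizes, not code from the commit): the outgoing source is the inbound source OR'd with the client's range start and, for writes, with the power-of-two offset; on the D channel those prefix bits are dropped again to recover the per-client ID.

// Sketch only: plain-Scala model of the source prefix/trim round trip, shown for the
// first client (rangeStart = 0) with an assumed total of 8 inbound source IDs.
object SourceIdRoundTripSketch extends App {
  val inSrcSize = 8                                        // power-of-two total of inbound source IDs
  def outgoing(inId: Int, rangeStart: Int, isRead: Boolean): Int =
    inId | rangeStart | (if (isRead) 0 else inSrcSize)     // as built for in_r / in_w above
  def trim(outId: Int, numIds: Int): Int =
    if (numIds <= 1) 0 else outId % numIds                 // power-of-two modulus drops the prefix bits
  val sent = outgoing(inId = 5, rangeStart = 0, isRead = false)
  assert(trim(sent, inSrcSize) == 5)                       // the original ID comes back on the response
  println(s"sent source ${sent}, recovered ${trim(sent, inSrcSize)}")
}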

} else if (use_shared_ext_mem) {
ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
}
34 changes: 17 additions & 17 deletions src/main/scala/gemmini/StoreController.scala
@@ -36,34 +36,34 @@ class StoreController[T <: Data : Arithmetic, U <: Data, V <: Data](config: Gemm

val control_state = RegInit(waiting_for_command)

val stride = Reg(UInt(coreMaxAddrBits.W))
val stride = RegInit(0.U(coreMaxAddrBits.W))
val block_rows = meshRows * tileRows
val block_stride = block_rows.U
val block_cols = meshColumns * tileColumns
val max_blocks = (dma_maxbytes / (block_cols * inputType.getWidth / 8)) max 1

val activation = Reg(UInt(Activation.bitwidth.W)) // TODO magic number
val igelu_qb = Reg(accType)
val igelu_qc = Reg(accType)
val iexp_qln2 = Reg(accType)
val iexp_qln2_inv = Reg(accType)
val norm_stats_id = Reg(UInt(8.W)) // TODO magic number
val acc_scale = Reg(acc_scale_t)
val activation = RegInit(0.U(Activation.bitwidth.W)) // TODO magic number
val igelu_qb = RegInit(0.U.asTypeOf(accType))
val igelu_qc = RegInit(0.U.asTypeOf(accType))
val iexp_qln2 = RegInit(0.U.asTypeOf(accType))
val iexp_qln2_inv = RegInit(0.U.asTypeOf(accType))
val norm_stats_id = RegInit(0.U(8.W)) // TODO magic number
val acc_scale = RegInit(0.U.asTypeOf(acc_scale_t))

//val row_counter = RegInit(0.U(log2Ceil(block_rows).W))
val row_counter = RegInit(0.U(12.W)) // TODO magic number
val block_counter = RegInit(0.U(8.W)) // TODO magic number

// Pooling variables
val pool_stride = Reg(UInt(CONFIG_MVOUT_RS1_MAX_POOLING_STRIDE_WIDTH.W)) // When this is 0, pooling is disabled
val pool_size = Reg(UInt(CONFIG_MVOUT_RS1_MAX_POOLING_WINDOW_SIZE_WIDTH.W))
val pool_out_dim = Reg(UInt(CONFIG_MVOUT_RS1_POOL_OUT_DIM_WIDTH.W))
val pool_porows = Reg(UInt(CONFIG_MVOUT_RS1_POOL_OUT_ROWS_WIDTH.W))
val pool_pocols = Reg(UInt(CONFIG_MVOUT_RS1_POOL_OUT_COLS_WIDTH.W))
val pool_orows = Reg(UInt(CONFIG_MVOUT_RS1_OUT_ROWS_WIDTH.W))
val pool_ocols = Reg(UInt(CONFIG_MVOUT_RS1_OUT_COLS_WIDTH.W))
val pool_upad = Reg(UInt(CONFIG_MVOUT_RS1_UPPER_ZERO_PADDING_WIDTH.W))
val pool_lpad = Reg(UInt(CONFIG_MVOUT_RS1_LEFT_ZERO_PADDING_WIDTH.W))
val pool_stride = RegInit(0.U(CONFIG_MVOUT_RS1_MAX_POOLING_STRIDE_WIDTH.W)) // When this is 0, pooling is disabled
val pool_size = RegInit(0.U(CONFIG_MVOUT_RS1_MAX_POOLING_WINDOW_SIZE_WIDTH.W))
val pool_out_dim = RegInit(0.U(CONFIG_MVOUT_RS1_POOL_OUT_DIM_WIDTH.W))
val pool_porows = RegInit(0.U(CONFIG_MVOUT_RS1_POOL_OUT_ROWS_WIDTH.W))
val pool_pocols = RegInit(0.U(CONFIG_MVOUT_RS1_POOL_OUT_COLS_WIDTH.W))
val pool_orows = RegInit(0.U(CONFIG_MVOUT_RS1_OUT_ROWS_WIDTH.W))
val pool_ocols = RegInit(0.U(CONFIG_MVOUT_RS1_OUT_COLS_WIDTH.W))
val pool_upad = RegInit(0.U(CONFIG_MVOUT_RS1_UPPER_ZERO_PADDING_WIDTH.W))
val pool_lpad = RegInit(0.U(CONFIG_MVOUT_RS1_LEFT_ZERO_PADDING_WIDTH.W))

val porow_counter = RegInit(0.U(pool_porows.getWidth.W))
val pocol_counter = RegInit(0.U(pool_pocols.getWidth.W))
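
The StoreController hunk applies a single pattern: Reg(...) leaves a register undefined until its first write, while RegInit(...) gives it a known value out of reset, and RegInit(0.U.asTypeOf(tpe)) zero-initializes a register of an arbitrary Chisel type such as accType or acc_scale_t. A minimal sketch of the difference (hypothetical module, not code from the commit):

import chisel3._

// Sketch only: illustrates the Reg -> RegInit change made throughout StoreController above.
class RegInitSketch extends Module {
  val io = IO(new Bundle {
    val wen = Input(Bool())
    val in  = Input(UInt(8.W))
    val out = Output(UInt(8.W))
  })
  // val r = Reg(UInt(8.W))   // old form: no reset value, undefined until the first write
  val r = RegInit(0.U(8.W))   // new form: holds 0 out of reset, so downstream logic sees a defined value
  when (io.wen) { r := io.in }
  io.out := r
}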