Skip to content

Commit

Permalink
back-to-back spad access and other mem node fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
richardyrh committed Feb 24, 2024
1 parent 19a12b2 commit c97cef9
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
55 changes: 39 additions & 16 deletions src/main/scala/gemmini/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
val create_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem

val num_ids = 32 // TODO (richard): move to config
val spad_base = x"ff000000"
val spad_base = config.tl_ext_mem_base

val unified_mem_read_node = TLIdentityNode()
val spad_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
Expand All @@ -65,6 +65,8 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA

val mem_depth = config.sp_bank_entries * spad_data_len / max_data_len
val mem_width = max_data_len
require(mem_depth * mem_width * config.sp_banks == 1 << 14, f"memory size is ${mem_depth}, ${mem_width}")
println(f"unified shared memory size: ${mem_depth}x${mem_width}x${config.sp_banks}")

// this node accepts both read and write requests,
// splits & arbitrates them into one client node per type of operation
Expand Down Expand Up @@ -196,10 +198,10 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA

val regDevice = new SimpleDevice("gemmini-cmd-reg", Seq(s"gemmini-cmd-reg"))
val regNode = TLRegisterNode(
address = Seq(AddressSet(0xff100000L, 0xfff)),
address = Seq(AddressSet(0xff007000L, 0xfff)),
device = regDevice,
beatBytes = 8,
concurrency = 1)
concurrency = 0)

regNode := TLFragmenter(8, 64) := stlNode

Expand All @@ -226,11 +228,11 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
val source_counters = Seq.fill(4)(Counter(outer.num_ids))

if (outer.create_tl_mem) {
def connect(ext_mem: ExtMemIO, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
def connect(ext_mem: ExtMemIO, bank_base: Int, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
w_node: TLBundle, w_edge: TLEdgeOut, w_source: Counter): Unit = {
r_node.a.valid := ext_mem.read_req.valid
r_node.a.bits := r_edge.Get(r_source.value,
(ext_mem.read_req.bits << req_size.U).asUInt | outer.spad_base.U,
(ext_mem.read_req.bits << req_size.U).asUInt | bank_base.U | outer.spad_base.U,
req_size.U)._2
ext_mem.read_req.ready := r_node.a.ready

Expand All @@ -239,7 +241,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]

w_node.a.valid := ext_mem.write_req.valid
w_node.a.bits := w_edge.Put(w_source.value,
w_shifted_addr | outer.spad_base.U,
w_shifted_addr | bank_base.U | outer.spad_base.U,
req_size.U, ext_mem.write_req.bits.data, w_mask)._2
ext_mem.write_req.ready := w_node.a.ready

Expand All @@ -254,7 +256,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
}
(outer.spad_read_nodes.out zip outer.spad_write_nodes.out)
.zipWithIndex.foreach{ case (((r_node, r_edge), (w_node, w_edge)), i) =>
connect(ext_mem_spad(i), log2Up(outer.spad_data_len),
connect(ext_mem_spad(i), i * outer.mem_depth * outer.mem_width, log2Up(outer.spad_data_len),
r_node, r_edge, source_counters(0), w_node, w_edge, source_counters(1))
}

Expand All @@ -273,7 +275,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]

// READ
mem.io.ren := r_node.a.fire
mem.io.raddr := r_node.a.bits.address ^ outer.spad_base.U
mem.io.raddr := (r_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U

val data_pipe_in = Wire(DecoupledIO(mem.io.rdata.cloneType))
data_pipe_in.valid := RegNext(mem.io.ren)
Expand All @@ -287,27 +289,48 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
metadata_pipe_in.bits.source := r_node.a.bits.source
metadata_pipe_in.bits.size := r_node.a.bits.size

val sram_read_backup_reg = RegInit(0.U.asTypeOf(Valid(mem.io.rdata.cloneType)))

val data_pipe_inst = Module(new Pipeline(data_pipe_in.bits.cloneType, 1)())
data_pipe_inst.io.in <> data_pipe_in
val data_pipe = data_pipe_inst.io.out
val metadata_pipe = Pipeline(metadata_pipe_in, 2)
assert(data_pipe_in.ready || !data_pipe_in.valid)
assert(metadata_pipe_in.ready || !data_pipe_in.ready)
assert(data_pipe.valid === metadata_pipe.valid)
assert((data_pipe.valid || sram_read_backup_reg.valid) === metadata_pipe.valid)

// data pipe is filled, but D is not ready and SRAM read came back
when (data_pipe.valid && !r_node.d.ready && data_pipe_in.valid) {
assert(!data_pipe_in.ready) // we should fill backup reg only if data pipe is not enqueueing
assert(!sram_read_backup_reg.valid) // backup reg should be empty
assert(!metadata_pipe_in.ready) // metadata should be filled previous cycle
sram_read_backup_reg.valid := true.B
sram_read_backup_reg.bits := mem.io.rdata
}.otherwise {
assert(data_pipe_in.ready || !data_pipe_in.valid) // do not skip any response
}

assert(metadata_pipe_in.fire || !mem.io.ren) // when requesting sram, metadata needs to be ready
assert(r_node.d.fire === metadata_pipe.fire) // metadata dequeues iff D fires

// when D becomes ready, and data pipe has emptied, time for backup to empty
when (r_node.d.ready && sram_read_backup_reg.valid && !data_pipe.valid) {
sram_read_backup_reg.valid := false.B
}
assert(!(sram_read_backup_reg.valid && data_pipe.valid && data_pipe_in.fire)) // must empty backup before filling data pipe
assert(data_pipe_in.valid === data_pipe_in.fire)

r_node.d.bits := r_edge.AccessAck(
metadata_pipe.bits.source,
metadata_pipe.bits.size,
data_pipe.bits)
r_node.d.valid := data_pipe.valid
// take new requests only we have the buffer slot open in case downstream becomes unready
r_node.a.ready := r_node.d.ready && !data_pipe_inst.io.busy
Mux(!data_pipe.valid, sram_read_backup_reg.bits, data_pipe.bits))
r_node.d.valid := data_pipe.valid || sram_read_backup_reg.valid
// r node A is not ready only if D is not ready and both slots filled
r_node.a.ready := r_node.d.ready && !(data_pipe.valid && sram_read_backup_reg.valid)
data_pipe.ready := r_node.d.ready
metadata_pipe.ready := r_node.d.ready

// WRITE
mem.io.wen := w_node.a.fire
mem.io.waddr := w_node.a.bits.address ^ outer.spad_base.U
mem.io.waddr := (w_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U
mem.io.wdata := w_node.a.bits.data
mem.io.mask := w_node.a.bits.mask.asBools
w_node.a.ready := w_node.d.ready// && (mem.io.waddr =/= mem.io.raddr)
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](

use_shared_ext_mem: Boolean = false,
use_tl_ext_mem: Boolean = false,
tl_ext_mem_base: BigInt = 0,
clock_gate: Boolean = false,

headerFileName: String = "gemmini_params.h"
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/gemmini/Scratchpad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,

spad_writer.module.io.req.valid := write_issue_q.io.deq.valid && writeData.valid && write_issue_q.io.deq.bits.dest.asBool
write_issue_q.io.deq.ready := writer.module.io.req.ready && spad_writer.module.io.req.ready && writeData.valid
spad_writer.module.io.req.bits.vaddr := write_issue_q.io.deq.bits.vaddr << 4.U // TODO(richard): do not hardcode
spad_writer.module.io.req.bits.vaddr := config.tl_ext_mem_base.U |
(write_issue_q.io.deq.bits.vaddr.asUInt << log2Ceil(config.DIM * config.inputType.getWidth / 8).U).asUInt
spad_writer.module.io.req.bits.physical := write_issue_q.io.deq.bits.dest
spad_writer.module.io.req.bits.len := Mux(writeData_is_full_width,
write_issue_q.io.deq.bits.len * (accType.getWidth / 8).U,
Expand Down

0 comments on commit c97cef9

Please sign in to comment.