diff --git a/src/main/scala/gemmini/Controller.scala b/src/main/scala/gemmini/Controller.scala
index 0cc6ac6e..fd92516c 100644
--- a/src/main/scala/gemmini/Controller.scala
+++ b/src/main/scala/gemmini/Controller.scala
@@ -38,7 +38,7 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
   val create_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem
 
   val num_ids = 32 // TODO (richard): move to config
-  val spad_base = x"ff000000"
+  val spad_base = config.tl_ext_mem_base
 
   val unified_mem_read_node = TLIdentityNode()
   val spad_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
@@ -65,6 +65,8 @@
 
   val mem_depth = config.sp_bank_entries * spad_data_len / max_data_len
   val mem_width = max_data_len
+  require(mem_depth * mem_width * config.sp_banks == 1 << 14, f"memory size is ${mem_depth}, ${mem_width}")
+  println(f"unified shared memory size: ${mem_depth}x${mem_width}x${config.sp_banks}")
 
   // this node accepts both read and write requests,
   // splits & arbitrates them into one client node per type of operation
@@ -196,10 +198,10 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
 
   val regDevice = new SimpleDevice("gemmini-cmd-reg", Seq(s"gemmini-cmd-reg"))
   val regNode = TLRegisterNode(
-    address = Seq(AddressSet(0xff100000L, 0xfff)),
+    address = Seq(AddressSet(0xff007000L, 0xfff)),
     device = regDevice,
     beatBytes = 8,
-    concurrency = 1)
+    concurrency = 0)
 
   regNode := TLFragmenter(8, 64) := stlNode
 
@@ -226,11 +228,11 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
   val source_counters = Seq.fill(4)(Counter(outer.num_ids))
 
   if (outer.create_tl_mem) {
-    def connect(ext_mem: ExtMemIO, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
+    def connect(ext_mem: ExtMemIO, bank_base: Int, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
                 w_node: TLBundle, w_edge: TLEdgeOut, w_source: Counter): Unit = {
      r_node.a.valid := ext_mem.read_req.valid
      r_node.a.bits := r_edge.Get(r_source.value,
-       (ext_mem.read_req.bits << req_size.U).asUInt | outer.spad_base.U,
+       (ext_mem.read_req.bits << req_size.U).asUInt | bank_base.U | outer.spad_base.U,
        req_size.U)._2
      ext_mem.read_req.ready := r_node.a.ready
 
@@ -239,7 +241,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
 
      w_node.a.valid := ext_mem.write_req.valid
      w_node.a.bits := w_edge.Put(w_source.value,
-       w_shifted_addr | outer.spad_base.U,
+       w_shifted_addr | bank_base.U | outer.spad_base.U,
        req_size.U, ext_mem.write_req.bits.data, w_mask)._2
      ext_mem.write_req.ready := w_node.a.ready
 
@@ -254,7 +256,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
    }
 
    (outer.spad_read_nodes.out zip outer.spad_write_nodes.out)
      .zipWithIndex.foreach{ case (((r_node, r_edge), (w_node, w_edge)), i) =>
-      connect(ext_mem_spad(i), log2Up(outer.spad_data_len),
+      connect(ext_mem_spad(i), i * outer.mem_depth * outer.mem_width, log2Up(outer.spad_data_len),
        r_node, r_edge, source_counters(0), w_node, w_edge, source_counters(1))
    }
@@ -273,7 +275,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
 
      // READ
      mem.io.ren := r_node.a.fire
-      mem.io.raddr := r_node.a.bits.address ^ outer.spad_base.U
+      mem.io.raddr := (r_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U
 
      val data_pipe_in = Wire(DecoupledIO(mem.io.rdata.cloneType))
      data_pipe_in.valid := RegNext(mem.io.ren)
@@ -287,27 +289,48 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
      metadata_pipe_in.bits.source := r_node.a.bits.source
      metadata_pipe_in.bits.size := r_node.a.bits.size
 
+      val sram_read_backup_reg = RegInit(0.U.asTypeOf(Valid(mem.io.rdata.cloneType)))
+
      val data_pipe_inst = Module(new Pipeline(data_pipe_in.bits.cloneType, 1)())
      data_pipe_inst.io.in <> data_pipe_in
      val data_pipe = data_pipe_inst.io.out
      val metadata_pipe = Pipeline(metadata_pipe_in, 2)
-      assert(data_pipe_in.ready || !data_pipe_in.valid)
-      assert(metadata_pipe_in.ready || !data_pipe_in.ready)
-      assert(data_pipe.valid === metadata_pipe.valid)
+      assert((data_pipe.valid || sram_read_backup_reg.valid) === metadata_pipe.valid)
+
+      // data pipe is filled, but D is not ready and SRAM read came back
+      when (data_pipe.valid && !r_node.d.ready && data_pipe_in.valid) {
+        assert(!data_pipe_in.ready) // we should fill backup reg only if data pipe is not enqueueing
+        assert(!sram_read_backup_reg.valid) // backup reg should be empty
+        assert(!metadata_pipe_in.ready) // metadata should be filled previous cycle
+        sram_read_backup_reg.valid := true.B
+        sram_read_backup_reg.bits := mem.io.rdata
+      }.otherwise {
+        assert(data_pipe_in.ready || !data_pipe_in.valid) // do not skip any response
+      }
+
+      assert(metadata_pipe_in.fire || !mem.io.ren) // when requesting sram, metadata needs to be ready
+      assert(r_node.d.fire === metadata_pipe.fire) // metadata dequeues iff D fires
+
+      // when D becomes ready, and data pipe has emptied, time for backup to empty
+      when (r_node.d.ready && sram_read_backup_reg.valid && !data_pipe.valid) {
+        sram_read_backup_reg.valid := false.B
+      }
+      assert(!(sram_read_backup_reg.valid && data_pipe.valid && data_pipe_in.fire)) // must empty backup before filling data pipe
+      assert(data_pipe_in.valid === data_pipe_in.fire)
 
      r_node.d.bits := r_edge.AccessAck(
        metadata_pipe.bits.source,
        metadata_pipe.bits.size,
-        data_pipe.bits)
-      r_node.d.valid := data_pipe.valid
-      // take new requests only we have the buffer slot open in case downstream becomes unready
-      r_node.a.ready := r_node.d.ready && !data_pipe_inst.io.busy
+        Mux(!data_pipe.valid, sram_read_backup_reg.bits, data_pipe.bits))
+      r_node.d.valid := data_pipe.valid || sram_read_backup_reg.valid
+      // r node A is not ready only if D is not ready and both slots filled
+      r_node.a.ready := r_node.d.ready && !(data_pipe.valid && sram_read_backup_reg.valid)
      data_pipe.ready := r_node.d.ready
      metadata_pipe.ready := r_node.d.ready
 
      // WRITE
      mem.io.wen := w_node.a.fire
-      mem.io.waddr := w_node.a.bits.address ^ outer.spad_base.U
+      mem.io.waddr := (w_node.a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width).U
      mem.io.wdata := w_node.a.bits.data
      mem.io.mask := w_node.a.bits.mask.asBools
      w_node.a.ready := w_node.d.ready// && (mem.io.waddr =/= mem.io.raddr)
diff --git a/src/main/scala/gemmini/GemminiConfigs.scala b/src/main/scala/gemmini/GemminiConfigs.scala
index 6873c683..a3aac838 100644
--- a/src/main/scala/gemmini/GemminiConfigs.scala
+++ b/src/main/scala/gemmini/GemminiConfigs.scala
@@ -93,6 +93,7 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
   use_shared_ext_mem: Boolean = false,
   use_tl_ext_mem: Boolean = false,
+  tl_ext_mem_base: BigInt = 0,
 
   clock_gate: Boolean = false,
 
   headerFileName: String = "gemmini_params.h"
diff --git a/src/main/scala/gemmini/Scratchpad.scala b/src/main/scala/gemmini/Scratchpad.scala
index b67da1f9..3963e6a7 100644
--- a/src/main/scala/gemmini/Scratchpad.scala
+++ b/src/main/scala/gemmini/Scratchpad.scala
@@ -328,7 +328,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
     spad_writer.module.io.req.valid := write_issue_q.io.deq.valid && writeData.valid &&
       write_issue_q.io.deq.bits.dest.asBool
     write_issue_q.io.deq.ready := writer.module.io.req.ready && spad_writer.module.io.req.ready && writeData.valid
-    spad_writer.module.io.req.bits.vaddr := write_issue_q.io.deq.bits.vaddr << 4.U // TODO(richard): do not hardcode
+    spad_writer.module.io.req.bits.vaddr := config.tl_ext_mem_base.U |
+      (write_issue_q.io.deq.bits.vaddr.asUInt << log2Ceil(config.DIM * config.inputType.getWidth / 8).U).asUInt
    spad_writer.module.io.req.bits.physical := write_issue_q.io.deq.bits.dest
    spad_writer.module.io.req.bits.len := Mux(writeData_is_full_width,
      write_issue_q.io.deq.bits.len * (accType.getWidth / 8).U,
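
For readers following the address changes above, the sketch below restates the new scratchpad TileLink addressing in plain Scala. It is illustrative only: the object name `SpadTlAddr`, its `log2Ceil` helper, and the concrete sizes (base `0xff000000`, 4 banks, 16-byte rows, 256 rows per bank, i.e. the 16 KiB total enforced by the new `require`) are assumptions made for the example, not values taken from this patch.

```scala
// Hypothetical sketch of the scratchpad TileLink addressing introduced above.
// Assumed values (NOT from the patch): tl_ext_mem_base = 0xff000000, sp_banks = 4,
// mem_width = 16 bytes/row, mem_depth = 256 rows/bank, so that
// mem_depth * mem_width * sp_banks == 1 << 14, satisfying the new require().
object SpadTlAddr {
  val tlExtMemBase = BigInt("ff000000", 16)
  val spBanks  = 4   // referenced only by the size assumption above
  val memWidth = 16  // bytes per row (max_data_len)
  val memDepth = 256 // rows per bank

  private def log2Ceil(x: Int): Int = 32 - Integer.numberOfLeadingZeros(x - 1)

  // Master side: mirrors `(req_bits << req_size) | bank_base.U | outer.spad_base.U`
  // in connect(), where bank_base = i * outer.mem_depth * outer.mem_width.
  def tlAddress(bank: Int, row: Int): BigInt =
    tlExtMemBase | BigInt(bank * memDepth * memWidth) | (BigInt(row) << log2Ceil(memWidth))

  // Device side: mirrors `(a.bits.address ^ outer.spad_base.U) >> log2Ceil(mem_width)`,
  // recovering a row offset within the unified memory.
  def unifiedRow(addr: BigInt): Int =
    ((addr ^ tlExtMemBase) >> log2Ceil(memWidth)).toInt

  def main(args: Array[String]): Unit = {
    val a = tlAddress(bank = 2, row = 5)
    println(s"bank 2, row 5 -> TL address 0x${a.toString(16)}, unified row ${unifiedRow(a)}")
  }
}
```

If a design wants to keep the previously hard-coded base, it would presumably set the new field when building its config, e.g. `GemminiConfigs.defaultConfig.copy(use_shared_ext_mem = true, use_tl_ext_mem = true, tl_ext_mem_base = BigInt("ff000000", 16))`; that exact config object and usage are an assumption here rather than part of this diff.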