move scratchpad to outside with TL interface, add mvout_spad instruction that writes scaled results to spad
richardyrh committed Jan 5, 2024
1 parent 709bc56 commit dffcfd1
Showing 11 changed files with 515 additions and 128 deletions.
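A rough sketch of the new instruction's semantics (hypothetical helper; `acc`, `spad`, and the operand assignment are assumptions — the actual decode lives in files whose diffs are not rendered on this page):

// Sketch only: mvout_spad reads an accumulator row, applies the configured
// scale (the same scale_func used by AccumulatorMem below), and writes the
// result to a local scratchpad row instead of streaming it to DRAM over the DMA.
def mvoutSpad(accAddr: UInt, spadAddr: UInt, scale: U): Unit = {
  val row    = acc.read(accAddr)                    // Vec of accumulator elements
  val scaled = VecInit(row.map(e => scale_func(e, scale)))
  spad.write(spadAddr, scaled.asUInt)               // destination is the spad, not DRAM
}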
2 changes: 1 addition & 1 deletion software/gemmini-rocc-tests
80 changes: 65 additions & 15 deletions src/main/scala/gemmini/AccumulatorMem.scala
@@ -92,7 +92,7 @@ class AccPipeShared[T <: Data : Arithmetic](latency: Int, t: Vec[Vec[T]], banks:
class AccumulatorMem[T <: Data, U <: Data](
n: Int, t: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U,
acc_singleported: Boolean, acc_sub_banks: Int,
use_shared_ext_mem: Boolean,
use_shared_ext_mem: Boolean, use_tl_ext_ram: Boolean,
acc_latency: Int, acc_type: T, is_dummy: Boolean
)
(implicit ev: Arithmetic[T]) extends Module {
@@ -134,8 +134,46 @@ class AccumulatorMem[T <: Data, U <: Data](
val mask_len = t.getWidth / 8
val mask_elem = UInt((t.getWidth / mask_len).W)

// val ext_mem_write_q_enq = if (use_shared_ext_mem && use_tl_ext_ram) {
// require(acc_sub_banks == 1)
// Some(io.ext_mem.get.map { ext_mem =>
// val write_q = Module(new Queue(new Bundle {
// val write_addr = UInt()
// val write_data = UInt()
// val write_mask = UInt()
// }, 8, pipe = true, flow = true))

// write_q.io.enq.valid := false.B
// write_q.io.enq.bits := DontCare

// ext_mem.write_valid := write_q.io.deq.valid
// ext_mem.write_addr := write_q.io.deq.bits.write_addr
// ext_mem.write_data := write_q.io.deq.bits.write_data
// ext_mem.write_mask := write_q.io.deq.bits.write_mask
// write_q.io.deq.ready := ext_mem.write_ready
// write_q.io.enq
// })
// } else None
io.ext_mem.get.foreach(_.write_req.valid := false.B)
io.ext_mem.get.foreach(_.write_req.bits.addr := 0.U(io.write.bits.addr.getWidth.W))
io.ext_mem.get.foreach(_.write_req.bits.mask := 0.U(io.write.bits.mask.getWidth.W))
io.ext_mem.get.foreach(_.write_req.bits.data := 0.U(io.write.bits.data.getWidth.W))
io.ext_mem.get.foreach(_.read_req.bits := 0.U((mask_len * mask_elem.getWidth).W))
io.ext_mem.get.foreach(_.read_req.valid := false.B)
io.ext_mem.get.foreach(_.read_resp.ready := false.B) // no reading from external accmem
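// (Note: the defaults above keep every external-memory port driven during
// elaboration; the per-sub-bank read/write closures further down override
// them whenever the shared-external-memory path is actually in use.)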
if (!acc_singleported && !is_dummy) {
require(!use_shared_ext_mem)
// if (use_shared_ext_mem && use_tl_ext_ram) {
// // duplicate write to external memory
// val enq = ext_mem_write_q_enq.get(0)
// enq.valid := oldest_pipelined_write.valid
// enq.bits.write_addr := oldest_pipelined_write.bits.addr
// enq.bits.write_data := Mux(oldest_pipelined_write.bits.acc, adder_sum.asUInt, oldest_pipelined_write.bits.data.asUInt)
// enq.bits.write_mask := oldest_pipelined_write.bits.mask.asUInt
// // TODO (richard): add buffer here and potentially propagate backpressure to systolic array
// assert(enq.ready || !enq.valid, "accumulator external memory write dropped")
// } else if (use_shared_ext_mem) {
// require(false, "cannot use two-port external acc mem bank")
// }
val mem = TwoPortSyncMem(n, t, mask_len) // TODO We assume byte-alignment here. Use aligned_to instead
mem.io.waddr := oldest_pipelined_write.bits.addr
mem.io.wen := oldest_pipelined_write.valid
@@ -163,27 +201,39 @@
for (i <- 0 until acc_sub_banks) {
def isThisBank(addr: UInt) = addr(log2Ceil(acc_sub_banks)-1,0) === i.U
def getBankIdx(addr: UInt) = addr >> log2Ceil(acc_sub_banks)
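// Worked example of the sub-bank decode (assuming acc_sub_banks = 4, i.e. 2 index bits):
//   addr = 13 = 0b1101 -> addr(1,0) = 0b01, so the row lives in sub-bank 1,
//   and getBankIdx(13) = 13 >> 2 = 3, i.e. row 3 within that sub-bank.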
val (read, write) = if (use_shared_ext_mem) {
val (read, write) = if (use_shared_ext_mem && !use_tl_ext_ram) {
def read(addr: UInt, ren: Bool): Data = {
io.ext_mem.get(i).read_en := ren
io.ext_mem.get(i).read_addr := addr
io.ext_mem.get(i).read_data
io.ext_mem.get(i).read_req.valid := ren
io.ext_mem.get(i).read_req.bits := addr
io.ext_mem.get(i).read_resp.bits
}
io.ext_mem.get(i).write_en := false.B
io.ext_mem.get(i).write_addr := DontCare
io.ext_mem.get(i).write_data := DontCare
io.ext_mem.get(i).write_mask := DontCare
io.ext_mem.get(i).write_req.bits := DontCare
def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = {
io.ext_mem.get(i).write_en := true.B
io.ext_mem.get(i).write_addr := addr
io.ext_mem.get(i).write_data := wdata.asUInt
io.ext_mem.get(i).write_mask := wmask.asUInt
io.ext_mem.get(i).write_req.valid := true.B
io.ext_mem.get(i).write_req.bits.addr := addr
io.ext_mem.get(i).write_req.bits.data := wdata.asUInt
io.ext_mem.get(i).write_req.bits.mask := wmask.asUInt
}
(read _, write _)
} else {
val mem = SyncReadMem(n / acc_sub_banks, Vec(mask_len, mask_elem))
io.ext_mem.get(i).read_req.bits := 0.U((mask_len * mask_elem.getWidth).W)
io.ext_mem.get(i).read_req.valid := false.B

def read(addr: UInt, ren: Bool): Data = mem.read(addr, ren)
def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = mem.write(addr, wdata, wmask)
def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = if (use_tl_ext_ram) {
mem.write(addr, wdata, wmask)
// duplicate write signal to external memory
// val enq = ext_mem_write_q_enq.get(i)
// enq.valid := true.B
// enq.bits.write_mask := wmask.asUInt
// enq.bits.write_addr := addr
// enq.bits.write_data := wdata.asUInt
// // TODO (richard): propagate backpressure to systolic array, add fence ability
// assert(enq.ready, "accumulator external memory write dropped")
} else {
mem.write(addr, wdata, wmask)
}
(read _, write _)
}

227 changes: 222 additions & 5 deletions src/main/scala/gemmini/Controller.scala
@@ -3,14 +3,13 @@ package gemmini

import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

import chisel3._
import chisel3.util._
import org.chipsalliance.cde.config._
import freechips.rocketchip.diplomacy._
import freechips.rocketchip.tile._
import freechips.rocketchip.util.ClockGate
import freechips.rocketchip.tilelink.TLIdentityNode
import freechips.rocketchip.tilelink.{TLBundle, TLClientNode, TLEdgeOut, TLFragmenter, TLIdentityNode, TLManagerNode, TLMasterParameters, TLMasterPortParameters, TLMasterToSlaveTransferSizes, TLRAM, TLSlaveParameters, TLSlavePortParameters, TLWidthWidget, TLXbar}
import GemminiISA._
import Util._

@@ -35,11 +34,115 @@ class Gemmini[T <: Data : Arithmetic, U <: Data, V <: Data](val config: GemminiA
val xLen = p(XLen)
val spad = LazyModule(new Scratchpad(config))

val create_tl_mem = config.use_shared_ext_mem && config.use_tl_ext_mem

val num_ids = 32 // TODO (richard): move to config
val spad_base = 0 // 0x60000000L

val unified_mem_read_node = TLIdentityNode()
val spad_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) {i =>
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_read_node_$i", sourceId = IdRange(0, num_ids))))
}) else TLIdentityNode()
// val acc_read_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_read_node_$i", sourceId = IdRange(0, numIDs))))
// }) else TLIdentityNode()

val unified_mem_write_node = TLIdentityNode()
val spad_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.sp_banks) { i =>
TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_write_node_$i", sourceId = IdRange(0, num_ids))))
}) else TLIdentityNode()

// val spad_dma_write_node = TLClientNode(Seq(
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"spad_dma_write_node", sourceId = IdRange(0, num_ids))))))
// val acc_write_nodes = if (create_tl_mem) TLClientNode(Seq.tabulate(config.acc_banks) { i =>
// TLMasterPortParameters.v1(Seq(TLMasterParameters.v1(name = s"acc_write_node_$i", sourceId = IdRange(0, numIDs))))
// }) else TLIdentityNode()

val spad_data_len = config.sp_width / 8
val acc_data_len = config.sp_width / config.inputType.getWidth * config.accType.getWidth / 8
val max_data_len = spad_data_len // max acc_data_len
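// Worked example (illustrative; assumes the default 16x16 config with
// sp_width = 128, 8-bit inputType, 32-bit accType):
//   spad_data_len = 128 / 8          = 16 bytes per scratchpad row
//   acc_data_len  = 128 / 8 * 32 / 8 = 64 bytes per accumulator row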

val spad_tl_ram : Seq[Seq[TLManagerNode]] = if (config.use_shared_ext_mem && config.use_tl_ext_mem) {
unified_mem_read_node :=* TLWidthWidget(spad_data_len) :=* spad_read_nodes
// unified_mem_read_node :=* TLWidthWidget(acc_data_len) :=* acc_read_nodes
unified_mem_write_node :=* TLWidthWidget(spad_data_len) :=* spad_write_nodes
// unified_mem_write_node :=* TLWidthWidget(acc_data_len) :=* acc_write_nodes

val stride_by_word = false // TODO (richard): move to config

require(isPow2(config.sp_banks))
val banks : Seq[Seq[TLManagerNode]] =
if (stride_by_word) {
assert(false, "TODO under construction")
assert((config.sp_capacity match { case CapacityInKilobytes(kb) => kb * 1024}) ==
config.sp_bank_entries * spad_data_len / max_data_len * config.sp_banks * max_data_len)
(0 until config.sp_banks).map { bank =>
LazyModule(new TLRAM(
address = AddressSet(max_data_len * bank,
((config.sp_bank_entries * spad_data_len / max_data_len - 1) * config.sp_banks + bank)
* max_data_len + (max_data_len - 1)),
beatBytes = max_data_len
))
}.map(x => Seq(x.node))
} else {
(0 until config.sp_banks).map { bank =>
val mem_depth = config.sp_bank_entries * spad_data_len / max_data_len
val mem_width = max_data_len

Seq(TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_read_mgr"),
address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
mem_depth * mem_width - 1)),
supports = TLMasterToSlaveTransferSizes(
get = TransferSizes(1, mem_width)),
fifoId = Some(0)
)),
beatBytes = mem_width
))),
TLManagerNode(Seq(TLSlavePortParameters.v1(
managers = Seq(TLSlaveParameters.v2(
name = Some(f"sp_bank${bank}_write_mgr"),
address = Seq(AddressSet(spad_base + (mem_depth * mem_width * bank),
mem_depth * mem_width - 1)),
supports = TLMasterToSlaveTransferSizes(
putFull = TransferSizes(1, mem_width),
putPartial = TransferSizes(1, mem_width)),
fifoId = Some(0)
)),
beatBytes = mem_width
))))
}
}
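// Resulting address map for the contiguous (non-strided) layout; numbers are
// illustrative, assuming a 256 KiB scratchpad in 4 banks with 16-byte rows
// (mem_depth = 4096, mem_width = 16):
//   bank 0: AddressSet(spad_base + 0x00000, 0xFFFF)
//   bank 1: AddressSet(spad_base + 0x10000, 0xFFFF)
//   ...and so on, with one read manager and one write manager per bank.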

require(!config.sp_singleported)
if (config.sp_singleported) {
val xbar = TLXbar()
xbar :=* unified_mem_read_node
xbar :=* unified_mem_write_node
banks.foreach(_.head := xbar)
} else {
val r_xbar = TLXbar()
val w_xbar = TLXbar()
r_xbar :=* unified_mem_read_node
w_xbar :=* unified_mem_write_node
banks.foreach { mem =>
require(mem.length == 2)
mem.head := r_xbar
mem.last := TLFragmenter(spad_data_len, spad.maxBytes) := w_xbar
}
}

banks
} else Seq()

override lazy val module = new GemminiModule(this)
override val tlNode = if (config.use_dedicated_tl_port) spad.id_node else TLIdentityNode()
override val atlNode = if (config.use_dedicated_tl_port) TLIdentityNode() else spad.id_node

val node = if (config.use_dedicated_tl_port) tlNode else atlNode

unified_mem_write_node := spad.spad_writer.node
}

class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
@@ -50,8 +153,121 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
import outer.config._
import outer.spad

val ext_mem_io = if (use_shared_ext_mem) Some(IO(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks))) else None
ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
val ext_mem_io = if (use_shared_ext_mem && !use_tl_ext_mem)
Some(IO(new ExtSpadMemIO(sp_banks, acc_banks, acc_sub_banks))) else None

// we need these 2 separate signals because ext_mem_io is not writable in this module
val ext_mem_spad = outer.spad.module.io.ext_mem.get.spad
val ext_mem_acc = outer.spad.module.io.ext_mem.get.acc

// connecting to unified TL interface
val source_counters = Seq.fill(4)(Counter(outer.num_ids))

if (outer.create_tl_mem) {
def connect(ext_mem: ExtMemIO, req_size: Int, r_node: TLBundle, r_edge: TLEdgeOut, r_source: Counter,
w_node: TLBundle, w_edge: TLEdgeOut, w_source: Counter): Unit = {
r_node.a.valid := ext_mem.read_req.valid
r_node.a.bits := r_edge.Get(r_source.value,
(ext_mem.read_req.bits << req_size.U).asUInt | outer.spad_base.U,
req_size.U)._2
ext_mem.read_req.ready := r_node.a.ready

val w_shifted_addr = (ext_mem.write_req.bits.addr << req_size.U).asUInt
val w_mask = (ext_mem.write_req.bits.mask << (w_shifted_addr & (w_edge.manager.beatBytes - 1).U)).asUInt

w_node.a.valid := ext_mem.write_req.valid
w_node.a.bits := w_edge.Put(w_source.value,
w_shifted_addr | outer.spad_base.U,
req_size.U, ext_mem.write_req.bits.data, w_mask)._2
ext_mem.write_req.ready := w_node.a.ready

ext_mem.read_resp.valid := r_node.d.valid
ext_mem.read_resp.bits := r_node.d.bits.data
r_node.d.ready := ext_mem.read_resp.ready

w_node.d.ready := true.B // writes are not acknowledged in gemmini

when(ext_mem.read_req.fire) { r_source.inc() }
when(ext_mem.write_req.fire) { w_source.inc() }
}
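// Mask-alignment sketch for the Put above (illustrative; assumes 16-byte
// rows, i.e. req_size = 4, and a 16-byte negotiated beat):
//   write_req.bits.addr = 3 -> w_shifted_addr = 0x30, and 0x30 & 0xF = 0,
//   so the row mask passes through unshifted. The shift only becomes nonzero
//   when the beat is wider than one row, in which case the byte mask must be
//   moved into the row's lane within the beat.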
(outer.spad_read_nodes.out zip outer.spad_write_nodes.out)
.zipWithIndex.foreach{ case (((r_node, r_edge), (w_node, w_edge)), i) =>
connect(ext_mem_spad(i), log2Up(outer.spad_data_len),
r_node, r_edge, source_counters(0), w_node, w_edge, source_counters(1))
}

outer.spad_tl_ram.foreach { case Seq(r, w) =>
val mem_depth = outer.config.sp_bank_entries * outer.spad_data_len / outer.max_data_len
val mem_width = outer.max_data_len

val mem = TwoPortSyncMem(
n = mem_depth,
t = UInt((mem_width * 8).W),
mask_len = mem_width // byte level mask
)

val (r_node, r_edge) = r.in.head
val (w_node, w_edge) = w.in.head

// READ
mem.io.ren := r_node.a.fire
mem.io.raddr := r_node.a.bits.address ^ outer.spad_base.U

val data_pipe_in = Wire(DecoupledIO(mem.io.rdata.cloneType))
data_pipe_in.valid := RegNext(mem.io.ren)
data_pipe_in.bits := mem.io.rdata

val metadata_pipe_in = Wire(DecoupledIO(new Bundle {
val source = r_node.a.bits.source.cloneType
val size = r_node.a.bits.size.cloneType
}))
metadata_pipe_in.valid := mem.io.ren
metadata_pipe_in.bits.source := r_node.a.bits.source
metadata_pipe_in.bits.size := r_node.a.bits.size

val data_pipe_inst = Module(new Pipeline(data_pipe_in.bits.cloneType, 1)())
data_pipe_inst.io.in <> data_pipe_in
val data_pipe = data_pipe_inst.io.out
val metadata_pipe = Pipeline(metadata_pipe_in, 2)
assert(data_pipe_in.ready || !data_pipe_in.valid)
assert(metadata_pipe_in.ready || !data_pipe_in.ready)
assert(data_pipe.valid === metadata_pipe.valid)
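// Timing note: the memory's read data arrives one cycle after the request,
// so the data takes a 1-stage pipe while the metadata (captured in the same
// cycle as the request) takes a 2-stage pipe; both reach the D channel
// together, which the asserts above check.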

r_node.d.bits := r_edge.AccessAck(
metadata_pipe.bits.source,
metadata_pipe.bits.size,
data_pipe.bits)
r_node.d.valid := data_pipe.valid
// take new requests only when we have a buffer slot open, in case downstream becomes unready
r_node.a.ready := r_node.d.ready && !data_pipe_inst.io.busy
data_pipe.ready := r_node.d.ready
metadata_pipe.ready := r_node.d.ready

// WRITE
mem.io.wen := w_node.a.fire
mem.io.waddr := w_node.a.bits.address ^ outer.spad_base.U
mem.io.wdata := w_node.a.bits.data
mem.io.mask := w_node.a.bits.mask.asBools
w_node.a.ready := w_node.d.ready // && (mem.io.waddr =/= mem.io.raddr)
w_node.d.valid := w_node.a.valid
w_node.d.bits := w_edge.AccessAck(w_node.a.bits)
}

ext_mem_acc.foreach(_.foreach(x => {
x.read_resp.bits := 0.U(1.W)
x.read_resp.valid := false.B
x.read_req.ready := false.B
x.write_req.ready := false.B
}))
// (outer.acc_read_nodes.out zip outer.acc_write_nodes.out)
// .zipWithIndex.foreach { case (((r_node, r_edge), (w_node, w_edge)), i) =>
// // TODO (richard): one subbank only for now
// connect(ext_mem_acc(i)(0), log2Up(outer.acc_data_len),
// r_node, r_edge, source_counters(2), w_node, w_edge, source_counters(3))
// }
} else if (use_shared_ext_mem) {
ext_mem_io.foreach(_ <> outer.spad.module.io.ext_mem.get)
}

val tagWidth = 32

@@ -66,7 +282,8 @@

// TLB
implicit val edge = outer.spad.id_node.edges.out.head
val tlb = Module(new FrontendTLB(2, tlb_size, dma_maxbytes, use_tlb_register_filter, use_firesim_simulation_counters, use_shared_tlb))
// TODO(richard): bypass TLB
val tlb = Module(new FrontendTLB(3, tlb_size, dma_maxbytes, use_tlb_register_filter, use_firesim_simulation_counters, use_shared_tlb))
(tlb.io.clients zip outer.spad.module.io.tlb).foreach(t => t._1 <> t._2)

tlb.io.exp.foreach(_.flush_skip := false.B)
12 changes: 11 additions & 1 deletion src/main/scala/gemmini/CustomConfigs.scala
@@ -49,8 +49,18 @@ object GemminiCustomConfigs {
acc_capacity = CapacityInKilobytes(128),
)

val unifiedMemConfig = defaultConfig.copy(
has_training_convs = false,
has_max_pool = false,
use_tl_ext_mem = true,
sp_singleported = false,
spad_read_delay = 8,
use_shared_ext_mem = true,
acc_sub_banks = 1
)
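// With both use_tl_ext_mem and use_shared_ext_mem set, create_tl_mem in
// Controller.scala is true, so this config elaborates the TL-managed
// scratchpad banks above and exposes the spad over TileLink rather than
// through the bare ExtSpadMemIO wires alone.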

// Specify which of your custom configs you want to build here
val customConfig = baselineInferenceConfig
val customConfig = unifiedMemConfig
}


(diffs for the remaining 7 changed files not rendered)
