Skip to content

Commit

Permalink
Make LoopMatmul FSM prefetch ahead
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryz123 committed Mar 9, 2021
1 parent 7f77625 commit 55c28b0
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 46 deletions.
5 changes: 4 additions & 1 deletion src/main/scala/gemmini/Controller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
val max_lds = rob_entries * 1 / 4
val max_exs = rob_entries * 3 / 4
val max_sts = rob_entries * 1 / 8
val (loop_cmd, loop_matmul_unroller_busy) = LoopMatmul(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization,
val (loop_cmd, loop_matmul_unroller_busy, prefetch) = LoopMatmul(raw_cmd, rob.io.ld_utilization, rob.io.st_utilization, rob.io.ex_utilization,
meshRows*tileRows, coreMaxAddrBits, rob_entries, max_lds, max_exs, max_sts, sp_banks * sp_bank_entries, acc_banks * acc_bank_entries,
inputType.getWidth, accType.getWidth, dma_maxbytes)
val unrolled_cmd = Queue(loop_cmd)
Expand Down Expand Up @@ -303,6 +303,9 @@ class GemminiModule[T <: Data: Arithmetic, U <: Data, V <: Data]
ex_controller.io.acc.read <> spad.module.io.acc.read
ex_controller.io.acc.write <> spad.module.io.acc.write

spad.module.io.prefetch <> prefetch
prefetch.ready := spad.module.io.prefetch.ready

// Im2Col unit
val im2col = Module(new Im2Col(outer.config))

Expand Down
76 changes: 51 additions & 25 deletions src/main/scala/gemmini/DMA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import chisel3.experimental.DataMirror

import freechips.rocketchip.config.Parameters
import freechips.rocketchip.diplomacy.{IdRange, LazyModule, LazyModuleImp}
import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters}
import freechips.rocketchip.tilelink.TLBundleA
import freechips.rocketchip.tile.{CoreBundle, HasCoreParameters, RoCCCommand}
import freechips.rocketchip.tilelink.{TLBundleA, TLHints, TLMessages}
import testchipip.TLHelper
import freechips.rocketchip.rocket.MStatus
import freechips.rocketchip.rocket.constants.MemoryOpConstants
Expand Down Expand Up @@ -58,6 +58,7 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T
val io = IO(new Bundle {
val req = Flipped(Decoupled(new StreamReadRequest(spad_rows, acc_rows, config.mvin_scale_t_bits)))
val resp = Decoupled(new StreamReadResponse(spadWidth, accWidth, spad_rows, acc_rows, aligned_to, config.mvin_scale_t_bits))
val prefetch = Flipped(Decoupled(new RoCCCommand))
val tlb = new FrontendTLBIO
val busy = Output(Bool())
val flush = Input(Bool())
Expand All @@ -70,16 +71,17 @@ class StreamReader[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T
val beatPacker = Module(new BeatMerger(beatBits, maxBytes, spadWidth, accWidth, spad_rows, acc_rows, maxBytes, aligned_to, meshRows, config.mvin_scale_t_bits, nCmds))

core.module.io.req <> io.req
core.module.io.prefetch <> io.prefetch
io.tlb <> core.module.io.tlb
io.busy := xactTracker.io.busy
core.module.io.flush := io.flush

xactTracker.io.alloc <> core.module.io.reserve
xactTracker.io.peek.xactid := RegEnableThru(core.module.io.beatData.bits.xactid, beatPacker.io.req.fire())
xactTracker.io.peek.pop := beatPacker.io.in.fire() && core.module.io.beatData.bits.last
xactTracker.io.peek.xactid := RegEnableThru(core.module.io.beatData.bits.xactid, core.module.io.beatData.fire())
xactTracker.io.peek.pop := core.module.io.beatData.fire() && core.module.io.beatData.bits.last

core.module.io.beatData.ready := beatPacker.io.in.ready
beatPacker.io.req.valid := core.module.io.beatData.valid
core.module.io.beatData.ready := beatPacker.io.in.ready || core.module.io.beatData.bits.is_hintack
beatPacker.io.req.valid := core.module.io.beatData.valid && !core.module.io.beatData.bits.is_hintack
beatPacker.io.req.bits := xactTracker.io.peek.entry
beatPacker.io.req.bits.lg_len_req := core.module.io.beatData.bits.lg_len_req
beatPacker.io.in.valid := core.module.io.beatData.valid
Expand All @@ -106,6 +108,7 @@ class StreamReadBeat (val nXacts: Int, val beatBits: Int, val maxReqBytes: Int)
val data = UInt(beatBits.W)
val lg_len_req = UInt(log2Up(log2Up(maxReqBytes+1)+1).W)
val last = Bool()
val is_hintack = Bool()
}

// TODO StreamReaderCore and StreamWriter are actually very alike. Is there some parent class they could both inherit from?
Expand All @@ -131,6 +134,7 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf

val io = IO(new Bundle {
val req = Flipped(Decoupled(new StreamReadRequest(spad_rows, acc_rows, config.mvin_scale_t_bits)))
val prefetch = Flipped(Decoupled(new RoCCCommand))
val reserve = new XactTrackerAllocIO(nXacts, maxBytes, spadWidth, accWidth, spad_rows, acc_rows, maxBytes, config.mvin_scale_t_bits, nCmds)
val beatData = Decoupled(new StreamReadBeat(nXacts, beatBits, maxBytes))
val tlb = new FrontendTLBIO
Expand Down Expand Up @@ -193,17 +197,53 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
lgSize = read_lg_size
)._2

// TileLink Hint (PREFETCH_READ) message template used for prefetch requests.
// It is selected as `untranslated_a.bits.tl_a` when an `io.prefetch` command is
// served instead of a regular `get`.
val prefetch = edge.Hint(
fromSource = io.reserve.xactid,
// Placeholder address: `tl.a.bits.address` is overwritten with the translated
// physical address (`io.tlb.resp.paddr`) before the message is sent.
toAddress = 0.U,
// NOTE(review): lgSize is presumably log2 of the hinted size in bytes (i.e. 2 bytes
// here) -- confirm this is the intended prefetch granularity.
lgSize = 1.U,
param = TLHints.PREFETCH_READ
// `._2` keeps only the second element of the tuple returned by `edge.Hint`.
)._2

// A TileLink A-channel message bundled with the untranslated (virtual) address and
// status of the access that produced it.
// NOTE(review): presumably carried alongside the message so the address can be
// translated through the TLB before `tl.a` is driven -- confirm against the
// translate queue logic.
class TLBundleAWithInfo extends Bundle {
// A-channel payload, cloned from the Chisel type of `tl.a.bits`.
val tl_a = DataMirror.internal.chiselTypeClone[TLBundleA](tl.a.bits)
// Virtual address of the access: `read_vaddr` for reads, `io.prefetch.bits.rs1`
// for prefetches.
val vaddr = Output(UInt(vaddrBits.W))
// Status of the originating request: `req.status` for reads,
// `io.prefetch.bits.status` for prefetches.
val status = Output(new MStatus)
}

val untranslated_a = Wire(Decoupled(new TLBundleAWithInfo))
untranslated_a.valid := state === s_req_new_block && io.reserve.ready
untranslated_a.bits.tl_a := get
untranslated_a.bits.vaddr := read_vaddr
untranslated_a.bits.status := req.status
untranslated_a.valid := false.B
untranslated_a.bits := DontCare
io.prefetch.ready := false.B
io.reserve.valid := false.B
when (state === s_req_new_block) {
io.reserve.valid := untranslated_a.ready
untranslated_a.valid := io.reserve.ready
untranslated_a.bits.tl_a := get
untranslated_a.bits.vaddr := read_vaddr
untranslated_a.bits.status := req.status

when (untranslated_a.fire()) {
val next_vaddr = req.vaddr + read_bytes_read // send_size
val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U
req.vaddr := next_vaddr

bytesRequested := bytesRequested + read_bytes_read // send_size

// when (send_size >= bytesLeft) {
when (read_bytes_read >= bytesLeft) {
// We're done with this request at this point
state_machine_ready_for_req := true.B
state := s_idle
}
}
} .elsewhen (io.prefetch.valid) {
io.reserve.valid := untranslated_a.ready
untranslated_a.valid := io.reserve.ready
io.prefetch.ready := untranslated_a.ready && io.reserve.ready
untranslated_a.bits.tl_a := prefetch
untranslated_a.bits.vaddr := io.prefetch.bits.rs1
untranslated_a.bits.status := io.prefetch.bits.status
}

// 0 goes to retries, 1 goes to state machine
val retry_a = Wire(Decoupled(new TLBundleAWithInfo))
Expand Down Expand Up @@ -233,7 +273,6 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
tl.a.bits := translate_q.io.deq.bits.tl_a
tl.a.bits.address := io.tlb.resp.paddr

io.reserve.valid := state === s_req_new_block && untranslated_a.ready // TODO decouple "reserve.valid" from "tl.a.ready"
io.reserve.entry.shift := read_shift
io.reserve.entry.is_acc := req.is_acc
io.reserve.entry.accumulate := req.accumulate
Expand All @@ -253,20 +292,6 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
if (bytesRequested.getWidth >= log2Up(spadWidthBytes+1)) bytesRequested / spadWidthBytes.U else 0.U)
io.reserve.entry.spad_row_offset := Mux(req.has_acc_bitwidth, bytesRequested % accWidthBytes.U, bytesRequested % spadWidthBytes.U)

when (untranslated_a.fire()) {
val next_vaddr = req.vaddr + read_bytes_read // send_size
val new_page = next_vaddr(pgIdxBits-1, 0) === 0.U
req.vaddr := next_vaddr

bytesRequested := bytesRequested + read_bytes_read // send_size

// when (send_size >= bytesLeft) {
when (read_bytes_read >= bytesLeft) {
// We're done with this request at this point
state_machine_ready_for_req := true.B
state := s_idle
}
}

// Forward TileLink read responses to the reservation buffer
tl.d.ready := io.beatData.ready
Expand All @@ -275,6 +300,7 @@ class StreamReaderCore[T <: Data, U <: Data, V <: Data](config: GemminiArrayConf
io.beatData.bits.data := tl.d.bits.data
io.beatData.bits.lg_len_req := tl.d.bits.size
io.beatData.bits.last := edge.last(tl.d)
io.beatData.bits.is_hintack := tl.d.bits.opcode === TLMessages.HintAck
// TODO the size data is already returned from TileLink, so there's no need for us to store it in the XactTracker ourselves

// Accepting requests to kick-start the state machine
Expand Down
88 changes: 68 additions & 20 deletions src/main/scala/gemmini/LoopMatmul.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val io = IO(new Bundle {
val req = Flipped(Decoupled(new LoopMatmulLdAReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops)))
val cmd = Decoupled(Output(new RoCCCommand))
val prefetch = Output(Valid(new RoCCCommand))
val i = Output(UInt(iterator_bitwidth.W))
val k = Output(UInt(iterator_bitwidth.W))
val idle = Output(Bool())
Expand Down Expand Up @@ -72,17 +73,31 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
mvin_cmd.rs1 := dram_addr
mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr

io.req.ready := state === idle
io.i := i
io.k := k
io.idle := state === idle

io.cmd.valid := state =/= idle && !io.rob_overloaded
io.cmd.bits := mvin_cmd

// Entry of `cmd_q`: a generated mvin command together with the (i, k) iterator
// values it was generated for, so that `io.i` / `io.k` can be driven from the
// command currently at the head of the queue.
class CmdQEntry extends Bundle {
// The RoCC command to issue (enqueued from `mvin_cmd`).
val cmd = new RoCCCommand
// Iterator values captured at enqueue time. Declared without an explicit width,
// so the widths are left to inference from the connected `i` / `k` signals.
val i = UInt()
val k = UInt()
}
val cmd_q = Module(new Queue(new CmdQEntry, 8))
cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded
cmd_q.io.enq.bits.cmd := mvin_cmd
cmd_q.io.enq.bits.i := i
cmd_q.io.enq.bits.k := k

io.cmd.valid := cmd_q.io.deq.valid
cmd_q.io.deq.ready := io.cmd.ready
io.cmd.bits := cmd_q.io.deq.bits.cmd

io.prefetch.valid := cmd_q.io.enq.fire()
io.prefetch.bits := cmd_q.io.enq.bits.cmd

io.req.ready := state === idle && !cmd_q.io.deq.valid
io.i := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.i, 0.U)
io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U)
io.loop_id := req.loop_id
io.idle := state === idle && !cmd_q.io.deq.valid

when (io.cmd.fire()) {
when (cmd_q.io.enq.fire()) {
// The order here is k, j, i
val i_blocks = Mux(req.transpose, max_blocks, 1.U)
val k_blocks = Mux(req.transpose, 1.U, max_blocks)
Expand Down Expand Up @@ -126,6 +141,7 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val io = IO(new Bundle {
val req = Flipped(Decoupled(new LoopMatmulLdBReq(block_size, coreMaxAddrBits, iterator_bitwidth, max_addr, concurrent_loops)))
val cmd = Decoupled(Output(new RoCCCommand))
val prefetch = Output(Valid(new RoCCCommand))

val k = Output(UInt(iterator_bitwidth.W))
val j = Output(UInt(iterator_bitwidth.W))
Expand Down Expand Up @@ -173,17 +189,31 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
mvin_cmd.rs1 := dram_addr
mvin_cmd.rs2 := (rows << 48).asUInt() | (cols << 32).asUInt() | sp_addr

io.req.ready := state === idle
io.k := k
io.j := j
io.idle := state === idle

io.cmd.valid := state =/= idle && !io.rob_overloaded
io.cmd.bits := mvin_cmd

// Entry of `cmd_q`: a generated mvin command together with the (k, j) iterator
// values it was generated for, so that `io.k` / `io.j` can be driven from the
// command currently at the head of the queue.
class CmdQEntry extends Bundle {
// The RoCC command to issue (enqueued from `mvin_cmd`).
val cmd = new RoCCCommand
// Iterator values captured at enqueue time. Declared without an explicit width,
// so the widths are left to inference from the connected `k` / `j` signals.
val k = UInt()
val j = UInt()
}
val cmd_q = Module(new Queue(new CmdQEntry, 8))
cmd_q.io.enq.valid := state =/= idle && !io.rob_overloaded
cmd_q.io.enq.bits.cmd := mvin_cmd
cmd_q.io.enq.bits.k := k
cmd_q.io.enq.bits.j := j

io.cmd.valid := cmd_q.io.deq.valid
cmd_q.io.deq.ready := io.cmd.ready
io.cmd.bits := cmd_q.io.deq.bits.cmd

io.prefetch.valid := cmd_q.io.enq.fire()
io.prefetch.bits := cmd_q.io.enq.bits.cmd

io.req.ready := state === idle && !cmd_q.io.deq.valid
io.k := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.k, 0.U)
io.j := Mux(cmd_q.io.deq.valid, cmd_q.io.deq.bits.j, 0.U)
io.loop_id := req.loop_id
io.idle := state === idle && !cmd_q.io.deq.valid

when (io.cmd.fire()) {
when (cmd_q.io.enq.fire()) {
// The order here is k, j, i
val j_blocks = Mux(req.transpose, 1.U, max_blocks)
val k_blocks = Mux(req.transpose, max_blocks, 1.U)
Expand Down Expand Up @@ -606,6 +636,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds:
val io = IO(new Bundle {
val in = Flipped(Decoupled(new RoCCCommand))
val out = Decoupled(new RoCCCommand)
val prefetch = Decoupled(new RoCCCommand)
val ld_utilization = Input(UInt(log2Up(rob_size+1).W))
val st_utilization = Input(UInt(log2Up(rob_size+1).W))
val ex_utilization = Input(UInt(log2Up(rob_size+1).W))
Expand Down Expand Up @@ -645,6 +676,23 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds:
ldab_arb.io.forceA := !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id
ldab_arb.io.forceB := !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id

val prefetch_arb = Module(new Arbiter(new RoCCCommand, 2))
prefetch_arb.io.in(0).valid := ldA.io.prefetch.fire()
prefetch_arb.io.in(0).bits := ldA.io.prefetch.bits
prefetch_arb.io.in(1).valid := ldB.io.prefetch.fire()
prefetch_arb.io.in(1).bits := ldB.io.prefetch.bits
val prefetch_q_size = 4
val prefetch_q = Module(new Queue(new RoCCCommand, prefetch_q_size, pipe=true))
io.prefetch <> prefetch_q.io.deq
io.prefetch.bits.status := cmd.bits.status // TODO This is not guaranteed to be the correct fix! We must fix this
prefetch_q.io.enq <> prefetch_arb.io.out
when (prefetch_q.io.enq.valid && prefetch_q.io.count === prefetch_q_size.U) {
prefetch_q.io.deq.ready := true.B
}

io.busy := cmd.valid || loop_configured


// Create global arbiter
val arb = Module(new Arbiter(new RoCCCommand(), 4))
arb.io.in(0) <> stC.io.cmd
Expand Down Expand Up @@ -890,13 +938,13 @@ object LoopMatmul {
def apply(in: DecoupledIO[RoCCCommand], ld_utilization: UInt, st_utilization: UInt, ex_utilization: UInt,
block_size: Int, coreMaxAddrBits: Int, rob_size: Int, max_lds: Int, max_exs: Int, max_sts: Int,
max_addr: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, dma_max_bytes: Int)
(implicit p: Parameters): Tuple2[DecoupledIO[RoCCCommand], Bool] = {
(implicit p: Parameters): (DecoupledIO[RoCCCommand], Bool, DecoupledIO[RoCCCommand]) = {
val mod = Module(new LoopMatmul(block_size, coreMaxAddrBits, rob_size, max_lds, max_exs, max_sts,
max_addr, max_acc_addr, input_w, acc_w, dma_max_bytes))
mod.io.in <> in
mod.io.ld_utilization := ld_utilization
mod.io.st_utilization := st_utilization
mod.io.ex_utilization := ex_utilization
(mod.io.out, mod.io.busy)
(mod.io.out, mod.io.busy, mod.io.prefetch)
}
}
3 changes: 3 additions & 0 deletions src/main/scala/gemmini/Scratchpad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
// Misc. ports
val busy = Output(Bool())
val flush = Input(Bool())

val prefetch = Flipped(Decoupled(new RoCCCommand))
})

val write_dispatch_q = Queue(io.dma.write.req)
Expand Down Expand Up @@ -269,6 +271,7 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
val (mvin_scale_in, mvin_scale_out) = VectorScalarMultiplier(config.mvin_scale_args, config.inputType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = false)
val (mvin_scale_acc_in, mvin_scale_acc_out) = if (mvin_scale_shared) (mvin_scale_in, mvin_scale_out) else
VectorScalarMultiplier(config.mvin_scale_acc_args, config.accType, config.meshColumns * config.tileColumns, chiselTypeOf(reader.module.io.resp.bits), is_acc = true)
reader.module.io.prefetch <> io.prefetch

mvin_scale_in.valid := reader.module.io.resp.valid && (mvin_scale_shared.B || !reader.module.io.resp.bits.is_acc ||
(reader.module.io.resp.bits.is_acc && !reader.module.io.resp.bits.has_acc_bitwidth))
Expand Down

0 comments on commit 55c28b0

Please sign in to comment.