Skip to content

Commit

Permalink
add scaling support in FSM
Browse files Browse the repository at this point in the history
  • Loading branch information
Seah Kim authored and Seah Kim committed Oct 16, 2023
1 parent 44f39f7 commit 6193a7a
Showing 1 changed file with 68 additions and 45 deletions.
113 changes: 68 additions & 45 deletions src/main/scala/gemmini/VegaLoopMatmul.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class VegaLoopMatmulLdAReq(val block_size: Int, val coreMaxAddrBits: Int, val it
//val transpose = Bool()
val addr_start = UInt(log2Up(max_addr).W)
val loop_id = UInt(log2Up(concurrent_loops).W)
//val is_resadd = Bool()
}

class VegaLoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, input_w: Int,
Expand Down Expand Up @@ -79,6 +78,10 @@ class VegaLoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
val blocks = Mux(col_iterator + max_blocks <= max_col_iterator, max_blocks, max_col_iterator-col_iterator)
val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U)
val rows = block_size.U - Mux(row_iterator === max_row_iterator-1.U, row_pad, 0.U)
dontTouch(dram_addr)
dontTouch(sp_addr)
dontTouch(rows)
dontTouch(cols)

val mvin_cmd = Wire(new RoCCCommand)
mvin_cmd := DontCare
Expand Down Expand Up @@ -145,7 +148,7 @@ class VegaLoopMatmulLdBReq(val block_size: Int, val coreMaxAddrBits: Int, val it
//val transpose = Bool()
val addr_end = UInt(log2Up(max_addr+1).W)
val loop_id = UInt(log2Up(concurrent_loops).W)
//val is_resadd = Bool()
val is_scale = Bool()
}

class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, input_w: Int,
Expand Down Expand Up @@ -199,8 +202,8 @@ class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
val max_col_dim = req.max_k
val max_blocks = Mux(max_col_dim <= max_block_len.U, max_col_dim, max_block_len.U)

//val sp_addr_start = Mux(req.is_resadd, req.addr_end, req.addr_end - req.max_k * req.max_j * block_size.U)
val sp_addr_start = req.addr_end - req.max_k
val sp_addr_start = Mux(req.is_scale, req.addr_end, req.addr_end - req.max_k)
//val sp_addr_start = req.addr_end - req.max_k

//val dram_offset = (row_iterator * req.dram_stride + col_iterator) * block_size.U * (input_w/8).U
val dram_offset = col_iterator * block_size.U * (input_w/8).U
Expand All @@ -210,9 +213,9 @@ class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
val blocks = Mux(col_iterator + max_blocks <= max_col_iterator, max_blocks, max_col_iterator-col_iterator)
val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U)
val rows = 1.U //block_size.U - Mux(max_row_iterator === max_row_iterator-1.U, row_pad, 0.U)
//dontTouch(dram_addr)
//dontTouch(sp_addr)
//dontTouch(cols)
dontTouch(dram_addr)
dontTouch(sp_addr)
dontTouch(cols)

val mvin_cmd = Wire(new RoCCCommand)
mvin_cmd := DontCare
Expand All @@ -225,11 +228,11 @@ class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
mvin_cmd_rs2.num_cols := cols.asUInt
mvin_cmd_rs2.local_addr := cast_to_sp_addr(mvin_cmd_rs2.local_addr, sp_addr)
mvin_cmd.rs2 := mvin_cmd_rs2.asUInt
/*
when (req.is_resadd){
mvin_cmd_rs2.local_addr := cast_to_acc_addr(mvin_cmd_rs2.local_addr, sp_addr, accumulate = true.B, read_full = false.B)

when (req.is_scale){
mvin_cmd_rs2.local_addr := cast_to_acc_addr(mvin_cmd_rs2.local_addr, sp_addr, accumulate = false.B, read_full = false.B)
}
*/

io.req.ready := state === idle
io.k := k
//io.j := j
Expand Down Expand Up @@ -386,7 +389,7 @@ class VegaLoopMatmulExecuteReq(val block_size: Int, val coreMaxAddrBits: Int, va
val b_addr_end = UInt(log2Up(max_addr+1).W)
val c_addr_start = UInt(log2Up(max_acc_addr).W)
val loop_id = UInt(log2Up(concurrent_loops).W)
//val skip = Bool()
val skip = Bool()
}

class VegaLoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, max_acc_addr: Int, concurrent_loops: Int,
Expand Down Expand Up @@ -513,12 +516,14 @@ class VegaLoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitw
val ldd_ahead = io.ldd_completed
val ld_ahead = lda_ahead && ldb_ahead && ldd_ahead

io.cmd.valid := state =/= idle && !io.rob_overloaded && ld_ahead // && !req.skip
io.cmd.valid := state =/= idle && !io.rob_overloaded && ld_ahead && !req.skip
io.cmd.bits := Mux(state === pre, pre_cmd, comp_cmd)

io.loop_id := req.loop_id

when (io.cmd.fire) {
when(req.skip){
state := idle
}.elsewhen (io.cmd.fire) {
when (state === pre) {
state := comp
}.otherwise {
Expand Down Expand Up @@ -559,7 +564,7 @@ class VegaLoopMatmulStCReq(val block_size: Int, val coreMaxAddrBits: Int, val it
val act = UInt(Activation.bitwidth.W)
val addr_start = UInt(log2Up(max_acc_addr).W)
val loop_id = UInt(log2Up(concurrent_loops).W)
//val is_resadd = Bool()
val is_scale = Bool()
}

class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, max_block_len: Int, concurrent_loops: Int, mvout_rs2_t: MvoutRs2)
Expand All @@ -583,7 +588,7 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
})

object State extends ChiselEnum {
val idle, st, ln_config, ln_st = Value
val idle, st = Value
}
import State._
val state = RegInit(idle)
Expand All @@ -605,11 +610,20 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
//val sp_addr = acc_addr_start + (i * req.max_j + j) * block_size.U
val sp_addr = acc_addr_start + i
val blocks = 1.U //Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j)
val cols = block_size.U //(blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U)
val cols = WireInit(block_size.U) //(blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U)
//val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U)
val rows = Mux(i + block_size.U > req.max_i, req.max_i - i, block_size.U)
//dontTouch(dram_addr)
//dontTouch(sp_addr)
val rows = WireInit(Mux(i + block_size.U > req.max_i, req.max_i - i, block_size.U))
//val rows = WireInit(block_size.U)
//when(req.is_scale){
when(req.pad_i =/= 0.U && i + block_size.U >= req.max_i && req.is_scale){
rows := 1.U
cols := Mux(i === req.max_i - 1.U, block_size.U - req.pad_i, block_size.U)
}

dontTouch(dram_addr)
dontTouch(sp_addr)
dontTouch(rows)
dontTouch(cols)

val mvout_cmd = Wire(new RoCCCommand)
mvout_cmd := DontCare
Expand All @@ -621,6 +635,12 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
mvout_cmd_rs2.num_rows := rows.asUInt
mvout_cmd_rs2.num_cols := cols.asUInt
mvout_cmd_rs2.local_addr := cast_to_acc_addr(mvout_cmd_rs2.local_addr, sp_addr, accumulate = false.B, read_full = req.full_c)
/*
when(req.is_scale){
mvout_cmd_rs2.local_addr := cast_to_sp_addr(mvout_cmd_rs2.local_addr, sp_addr)
}
*/
mvout_cmd.rs2 := mvout_cmd_rs2.asUInt

io.req.ready := state === idle
Expand All @@ -645,7 +665,9 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
(io.ex_k === req.max_k - 1.U &&
(io.ex_i >= i + block_size.U)))
//(io.ex_i > i)))

when(req.is_scale){
ex_ahead := io.ex_completed || (io.ex_i >= i + block_size.U)
}
io.cmd.valid := state =/= idle && !io.rob_overloaded && ex_ahead && req.dram_addr =/= 0.U
io.cmd.bits := mvout_cmd

Expand All @@ -655,8 +677,9 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth
state := idle
}.elsewhen (io.cmd.fire() && state === st) {
// The order here is k, j, i
val next_i_size = Mux(i + block_size.U >= req.max_i && req.pad_i =/= 0.U && req.is_scale, 1.U, block_size.U)
//val next_i = floorAdd(i, 1.U, req.max_i)
val next_i = floorAdd(i, block_size.U, req.max_i)
val next_i = floorAdd(i, next_i_size, req.max_i)
//val next_j = floorAdd(j, max_blocks, req.max_j, next_i === 0.U)

i := next_i
Expand Down Expand Up @@ -752,7 +775,7 @@ class VegaLoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int,

val a_addr_start = UInt(log2Up(max_addr).W)
val b_addr_end = UInt(log2Up(max_addr+1).W)
val resadd_addr_start = UInt(log2Up(max_acc_addr).W)
val scale_addr_start = UInt(log2Up(max_acc_addr).W)

def reset(): Unit = {
configured := false.B
Expand Down Expand Up @@ -808,7 +831,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
val loop_being_configured_id = Mux(head_loop.configured, tail_loop_id, head_loop_id)
val loop_being_configured = loops(loop_being_configured_id)

//val is_resadd = RegInit(false.B)
val is_scale = RegInit(false.B)

val max_all_addr = if(max_addr > max_acc_addr) max_addr else max_acc_addr
// Create inner modules
Expand All @@ -830,8 +853,8 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
val ab_loads_on_same_loop = ldA.io.loop_id === ldB.io.loop_id
val forceA = !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id
val forceB = !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id
ldab_arb.io.forceA := forceA //Mux(is_resadd, ab_loads_on_same_loop && !ldA.io.idle, forceA)
ldab_arb.io.forceB := forceB //Mux(is_resadd, forceB || ldA.io.idle, forceB)
ldab_arb.io.forceA := Mux(is_scale, false.B, forceA) //Mux(is_resadd, ab_loads_on_same_loop && !ldA.io.idle, forceA)
ldab_arb.io.forceB := Mux(is_scale, true.B, forceB) //Mux(is_resadd, forceB || ldA.io.idle, forceB)
ldab_arb.io.weightA := 0.U
ldab_arb.io.inA_idle := ldA.io.idle
ldab_arb.io.inB_idle := ldB.io.idle
Expand Down Expand Up @@ -909,15 +932,15 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
// when loop matmul is used as a scale unroller (is_scale; formerly resadd)
// skip ex
// track ldB instead of ex
/*
when(is_resadd){
stC.io.ex_completed := (ldA.io.loop_id =/= stC.io.loop_id || ldA.io.idle) && (ldB.io.loop_id =/= stC.io.loop_id || ldB.io.idle)

when(is_scale){
stC.io.ex_completed := (ldB.io.loop_id =/= stC.io.loop_id || ldB.io.idle)// && (ldB.io.loop_id =/= stC.io.loop_id || ldB.io.idle)
stC.io.ex_k := 0.U // req.max_k shall be 1
//stC.io.ex_j := ldB.io.j
stC.io.ex_i := ldB.io.k
//ldB.io.rob_overloaded := ld_utilization >= max_lds.U || !((ldA.io.loop_id =/= ldB.io.loop_id) || ldA.io.idle)
}
*/


val loops_configured = RegInit(0.U(16.W))
//dontTouch(loops_configured)
Expand Down Expand Up @@ -966,7 +989,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
loop_being_configured.b_ex_spad_id := cmd.bits.cmd.rs1(17, 16)
//loop_being_configured.a_transpose := cmd.bits.cmd.rs2(0)
//loop_being_configured.b_transpose := cmd.bits.cmd.rs2(1)
//is_resadd := cmd.bits.cmd.rs2(2)
is_scale := cmd.bits.cmd.rs1(32)
loop_being_configured.a_dram_stride := cmd.bits.cmd.rs2

loop_being_configured.configured := true.B
Expand All @@ -992,7 +1015,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
//ldA.io.req.bits.transpose := loop_requesting_ldA.a_transpose
ldA.io.req.bits.addr_start := Mux(loop_requesting_ldA.a_ex_spad_id === 0.U, loop_requesting_ldA.a_addr_start, (loop_requesting_ldA.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U)
ldA.io.req.bits.loop_id := loop_requesting_ldA_id
//ldA.io.req.bits.is_resadd := is_resadd
//ldA.io.req.bits.is_scale := is_scale

ldA.io.req.valid := !loop_requesting_ldA.lda_started && loop_requesting_ldA.configured

Expand All @@ -1012,7 +1035,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
//ldB.io.req.bits.transpose := loop_requesting_ldB.b_transpose
ldB.io.req.bits.addr_end := Mux(loop_requesting_ldB.b_ex_spad_id === 0.U, loop_requesting_ldB.b_addr_end, (loop_requesting_ldB.b_ex_spad_id) * (max_addr / concurrent_loops).U)
ldB.io.req.bits.loop_id := loop_requesting_ldB_id
//ldB.io.req.bits.is_resadd := is_resadd
ldB.io.req.bits.is_scale := is_scale

ldB.io.req.valid := !loop_requesting_ldB.ldb_started && loop_requesting_ldB.configured

Expand All @@ -1036,7 +1059,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
//ex.io.req.bits.b_tranpose := loop_requesting_ex.b_transpose
ex.io.req.bits.c_addr_start := ex_c_addr_start
ex.io.req.bits.loop_id := loop_requesting_ex_id
//ex.io.req.bits.skip := is_resadd
ex.io.req.bits.skip := is_scale

ex.io.req.valid := !loop_requesting_ex.ex_started && loop_requesting_ex.lda_started &&
loop_requesting_ex.ldb_started && loop_requesting_ex.ldd_started && loop_requesting_ex.configured
Expand Down Expand Up @@ -1075,18 +1098,18 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_

val loop_requesting_st_id = Mux(head_loop.st_started, tail_loop_id, head_loop_id)
val loop_requesting_st = loops(loop_requesting_st_id)
stC.io.req.bits.max_k := loop_requesting_st.max_k //Mux(is_resadd, 1.U, loop_requesting_st.max_k)
stC.io.req.bits.max_k := Mux(is_scale, loop_requesting_st.max_i, loop_requesting_st.max_k) //Mux(is_resadd, 1.U, loop_requesting_st.max_k)
//stC.io.req.bits.max_j := loop_requesting_st.max_j
stC.io.req.bits.max_i := loop_requesting_st.max_i
stC.io.req.bits.max_i := Mux(is_scale, loop_requesting_st.max_k, loop_requesting_st.max_i)
//stC.io.req.bits.pad_j := loop_requesting_st.pad_j
stC.io.req.bits.pad_i := loop_requesting_st.pad_i
stC.io.req.bits.pad_i := loop_requesting_st.pad_i //Mux(is_scale, loop_requesting_st.pad_k, loop_requesting_st.pad_i)
stC.io.req.bits.dram_addr := loop_requesting_st.c_dram_addr
//stC.io.req.bits.dram_stride := loop_requesting_st.c_dram_stride
stC.io.req.bits.full_c := loop_requesting_st.full_c
stC.io.req.bits.act := loop_requesting_st.act
stC.io.req.bits.addr_start := st_c_addr_start
stC.io.req.bits.loop_id := loop_requesting_st_id
//stC.io.req.bits.is_resadd := is_resadd
stC.io.req.bits.is_scale := is_scale


stC.io.req.valid := !loop_requesting_st.st_started && loop_requesting_st.ex_started && loop_requesting_st.configured
Expand All @@ -1099,15 +1122,15 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
st_c_addr_start := floorAdd(st_c_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U)
}
}
/*
when(is_resadd){
ldA.io.req.bits.addr_start := loop_requesting_ldA.resadd_addr_start
ldB.io.req.bits.addr_end := loop_requesting_ldB.resadd_addr_start
stC.io.req.bits.addr_start := loop_requesting_st.resadd_addr_start

when(is_scale){
//ldA.io.req.bits.addr_start := loop_requesting_ldA.scale_addr_start
ldB.io.req.bits.addr_end := loop_requesting_ldB.scale_addr_start
stC.io.req.bits.addr_start := loop_requesting_st.scale_addr_start
stC.io.req.valid := !loop_requesting_st.st_started && loop_requesting_st.configured
}

*/

// Handle completed signals
when (ldA.io.idle && loops(ldA.io.loop_id).running && loops(ldA.io.loop_id).lda_started) {
loops(ldA.io.loop_id).lda_completed := true.B
Expand Down Expand Up @@ -1140,7 +1163,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_
l.reset()
l.a_addr_start := (i * (max_addr / concurrent_loops)).U
l.b_addr_end := ((i+1) * (max_addr / concurrent_loops)).U
l.resadd_addr_start := (i * (max_acc_addr / concurrent_loops)).U
l.scale_addr_start := (i * (max_acc_addr / concurrent_loops)).U
}
}
}
Expand Down

0 comments on commit 6193a7a

Please sign in to comment.