diff --git a/src/main/scala/gemmini/VegaLoopMatmul.scala b/src/main/scala/gemmini/VegaLoopMatmul.scala index bba27010..f9d17e48 100644 --- a/src/main/scala/gemmini/VegaLoopMatmul.scala +++ b/src/main/scala/gemmini/VegaLoopMatmul.scala @@ -22,7 +22,6 @@ class VegaLoopMatmulLdAReq(val block_size: Int, val coreMaxAddrBits: Int, val it //val transpose = Bool() val addr_start = UInt(log2Up(max_addr).W) val loop_id = UInt(log2Up(concurrent_loops).W) - //val is_resadd = Bool() } class VegaLoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, input_w: Int, @@ -79,6 +78,10 @@ class VegaLoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth val blocks = Mux(col_iterator + max_blocks <= max_col_iterator, max_blocks, max_col_iterator-col_iterator) val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U) val rows = block_size.U - Mux(row_iterator === max_row_iterator-1.U, row_pad, 0.U) + dontTouch(dram_addr) + dontTouch(sp_addr) + dontTouch(rows) + dontTouch(cols) val mvin_cmd = Wire(new RoCCCommand) mvin_cmd := DontCare @@ -145,7 +148,7 @@ class VegaLoopMatmulLdBReq(val block_size: Int, val coreMaxAddrBits: Int, val it //val transpose = Bool() val addr_end = UInt(log2Up(max_addr+1).W) val loop_id = UInt(log2Up(concurrent_loops).W) - //val is_resadd = Bool() + val is_scale = Bool() } class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, input_w: Int, @@ -199,8 +202,8 @@ class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth val max_col_dim = req.max_k val max_blocks = Mux(max_col_dim <= max_block_len.U, max_col_dim, max_block_len.U) - //val sp_addr_start = Mux(req.is_resadd, req.addr_end, req.addr_end - req.max_k * req.max_j * block_size.U) - val sp_addr_start = req.addr_end - req.max_k + val sp_addr_start = Mux(req.is_scale, req.addr_end, req.addr_end - req.max_k) + //val sp_addr_start = req.addr_end - req.max_k //val dram_offset = (row_iterator * req.dram_stride + col_iterator) * block_size.U * (input_w/8).U val dram_offset = col_iterator * block_size.U * (input_w/8).U @@ -210,9 +213,9 @@ class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth val blocks = Mux(col_iterator + max_blocks <= max_col_iterator, max_blocks, max_col_iterator-col_iterator) val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U) val rows = 1.U //block_size.U - Mux(max_row_iterator === max_row_iterator-1.U, row_pad, 0.U) - //dontTouch(dram_addr) - //dontTouch(sp_addr) - //dontTouch(cols) + dontTouch(dram_addr) + dontTouch(sp_addr) + dontTouch(cols) val mvin_cmd = Wire(new RoCCCommand) mvin_cmd := DontCare @@ -225,11 +228,11 @@ class VegaLoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth mvin_cmd_rs2.num_cols := cols.asUInt mvin_cmd_rs2.local_addr := cast_to_sp_addr(mvin_cmd_rs2.local_addr, sp_addr) mvin_cmd.rs2 := mvin_cmd_rs2.asUInt -/* - when (req.is_resadd){ - mvin_cmd_rs2.local_addr := cast_to_acc_addr(mvin_cmd_rs2.local_addr, sp_addr, accumulate = true.B, read_full = false.B) + + when (req.is_scale){ + mvin_cmd_rs2.local_addr := cast_to_acc_addr(mvin_cmd_rs2.local_addr, sp_addr, accumulate = false.B, read_full = false.B) } -*/ + io.req.ready := state === idle io.k := k //io.j := j @@ -386,7 +389,7 @@ class VegaLoopMatmulExecuteReq(val block_size: Int, val coreMaxAddrBits: Int, va val b_addr_end = UInt(log2Up(max_addr+1).W) val c_addr_start = UInt(log2Up(max_acc_addr).W) val loop_id = UInt(log2Up(concurrent_loops).W) - //val skip = Bool() + val skip = Bool() } class VegaLoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_addr: Int, max_acc_addr: Int, concurrent_loops: Int, @@ -513,12 +516,14 @@ class VegaLoopMatmulExecute(block_size: Int, coreMaxAddrBits: Int, iterator_bitw val ldd_ahead = io.ldd_completed val ld_ahead = lda_ahead && ldb_ahead && ldd_ahead - io.cmd.valid := state =/= idle && !io.rob_overloaded && ld_ahead // && !req.skip + io.cmd.valid := state =/= idle && !io.rob_overloaded && ld_ahead && !req.skip io.cmd.bits := Mux(state === pre, pre_cmd, comp_cmd) io.loop_id := req.loop_id - when (io.cmd.fire) { + when(req.skip){ + state := idle + }.elsewhen (io.cmd.fire) { when (state === pre) { state := comp }.otherwise { @@ -559,7 +564,7 @@ class VegaLoopMatmulStCReq(val block_size: Int, val coreMaxAddrBits: Int, val it val act = UInt(Activation.bitwidth.W) val addr_start = UInt(log2Up(max_acc_addr).W) val loop_id = UInt(log2Up(concurrent_loops).W) - //val is_resadd = Bool() + val is_scale = Bool() } class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int, acc_w: Int, max_block_len: Int, concurrent_loops: Int, mvout_rs2_t: MvoutRs2) @@ -583,7 +588,7 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth }) object State extends ChiselEnum { - val idle, st, ln_config, ln_st = Value + val idle, st = Value } import State._ val state = RegInit(idle) @@ -605,11 +610,20 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth //val sp_addr = acc_addr_start + (i * req.max_j + j) * block_size.U val sp_addr = acc_addr_start + i val blocks = 1.U //Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j) - val cols = block_size.U //(blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U) + val cols = WireInit(block_size.U) //(blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U) //val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U) - val rows = Mux(i + block_size.U > req.max_i, req.max_i - i, block_size.U) - //dontTouch(dram_addr) - //dontTouch(sp_addr) + val rows = WireInit(Mux(i + block_size.U > req.max_i, req.max_i - i, block_size.U)) + //val rows = WireInit(block_size.U) + //when(req.is_scale){ + when(req.pad_i =/= 0.U && i + block_size.U >= req.max_i && req.is_scale){ + rows := 1.U + cols := Mux(i === req.max_i - 1.U, block_size.U - req.pad_i, block_size.U) + } + + dontTouch(dram_addr) + dontTouch(sp_addr) + dontTouch(rows) + dontTouch(cols) val mvout_cmd = Wire(new RoCCCommand) mvout_cmd := DontCare @@ -621,6 +635,12 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth mvout_cmd_rs2.num_rows := rows.asUInt mvout_cmd_rs2.num_cols := cols.asUInt mvout_cmd_rs2.local_addr := cast_to_acc_addr(mvout_cmd_rs2.local_addr, sp_addr, accumulate = false.B, read_full = req.full_c) + /* + when(req.is_scale){ + mvout_cmd_rs2.local_addr := cast_to_sp_addr(mvout_cmd_rs2.local_addr, sp_addr) + } + + */ mvout_cmd.rs2 := mvout_cmd_rs2.asUInt io.req.ready := state === idle @@ -645,7 +665,9 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth (io.ex_k === req.max_k - 1.U && (io.ex_i >= i + block_size.U))) //(io.ex_i > i))) - + when(req.is_scale){ + ex_ahead := io.ex_completed || (io.ex_i >= i + block_size.U) + } io.cmd.valid := state =/= idle && !io.rob_overloaded && ex_ahead && req.dram_addr =/= 0.U io.cmd.bits := mvout_cmd @@ -655,8 +677,9 @@ class VegaLoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth state := idle }.elsewhen (io.cmd.fire() && state === st) { // The order here is k, j, i + val next_i_size = Mux(i + block_size.U >= req.max_i && req.pad_i =/= 0.U && req.is_scale, 1.U, block_size.U) //val next_i = floorAdd(i, 1.U, req.max_i) - val next_i = floorAdd(i, block_size.U, req.max_i) + val next_i = floorAdd(i, next_i_size, req.max_i) //val next_j = floorAdd(j, max_blocks, req.max_j, next_i === 0.U) i := next_i @@ -752,7 +775,7 @@ class VegaLoopMatmulState(val iterator_bitwidth: Int, val coreMaxAddrBits: Int, val a_addr_start = UInt(log2Up(max_addr).W) val b_addr_end = UInt(log2Up(max_addr+1).W) - val resadd_addr_start = UInt(log2Up(max_acc_addr).W) + val scale_addr_start = UInt(log2Up(max_acc_addr).W) def reset(): Unit = { configured := false.B @@ -808,7 +831,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ val loop_being_configured_id = Mux(head_loop.configured, tail_loop_id, head_loop_id) val loop_being_configured = loops(loop_being_configured_id) - //val is_resadd = RegInit(false.B) + val is_scale = RegInit(false.B) val max_all_addr = if(max_addr > max_acc_addr) max_addr else max_acc_addr // Create inner modules @@ -830,8 +853,8 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ val ab_loads_on_same_loop = ldA.io.loop_id === ldB.io.loop_id val forceA = !ab_loads_on_same_loop && ldA.io.loop_id === head_loop_id val forceB = !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id - ldab_arb.io.forceA := forceA //Mux(is_resadd, ab_loads_on_same_loop && !ldA.io.idle, forceA) - ldab_arb.io.forceB := forceB //Mux(is_resadd, forceB || ldA.io.idle, forceB) + ldab_arb.io.forceA := Mux(is_scale, false.B, forceA) //Mux(is_resadd, ab_loads_on_same_loop && !ldA.io.idle, forceA) + ldab_arb.io.forceB := Mux(is_scale, true.B, forceB) //Mux(is_resadd, forceB || ldA.io.idle, forceB) ldab_arb.io.weightA := 0.U ldab_arb.io.inA_idle := ldA.io.idle ldab_arb.io.inB_idle := ldB.io.idle @@ -909,15 +932,15 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ // when loop matmul is used as resadd unroller // skip ex // track ldB instead of ex - /* - when(is_resadd){ - stC.io.ex_completed := (ldA.io.loop_id =/= stC.io.loop_id || ldA.io.idle) && (ldB.io.loop_id =/= stC.io.loop_id || ldB.io.idle) + + when(is_scale){ + stC.io.ex_completed := (ldB.io.loop_id =/= stC.io.loop_id || ldB.io.idle)// && (ldB.io.loop_id =/= stC.io.loop_id || ldB.io.idle) stC.io.ex_k := 0.U // req.max_k shall be 1 //stC.io.ex_j := ldB.io.j stC.io.ex_i := ldB.io.k //ldB.io.rob_overloaded := ld_utilization >= max_lds.U || !((ldA.io.loop_id =/= ldB.io.loop_id) || ldA.io.idle) } - */ + val loops_configured = RegInit(0.U(16.W)) //dontTouch(loops_configured) @@ -966,7 +989,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ loop_being_configured.b_ex_spad_id := cmd.bits.cmd.rs1(17, 16) //loop_being_configured.a_transpose := cmd.bits.cmd.rs2(0) //loop_being_configured.b_transpose := cmd.bits.cmd.rs2(1) - //is_resadd := cmd.bits.cmd.rs2(2) + is_scale := cmd.bits.cmd.rs1(32) loop_being_configured.a_dram_stride := cmd.bits.cmd.rs2 loop_being_configured.configured := true.B @@ -992,7 +1015,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ //ldA.io.req.bits.transpose := loop_requesting_ldA.a_transpose ldA.io.req.bits.addr_start := Mux(loop_requesting_ldA.a_ex_spad_id === 0.U, loop_requesting_ldA.a_addr_start, (loop_requesting_ldA.a_ex_spad_id - 1.U) * (max_addr / concurrent_loops).U) ldA.io.req.bits.loop_id := loop_requesting_ldA_id - //ldA.io.req.bits.is_resadd := is_resadd + //ldA.io.req.bits.is_scale := is_scale ldA.io.req.valid := !loop_requesting_ldA.lda_started && loop_requesting_ldA.configured @@ -1012,7 +1035,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ //ldB.io.req.bits.transpose := loop_requesting_ldB.b_transpose ldB.io.req.bits.addr_end := Mux(loop_requesting_ldB.b_ex_spad_id === 0.U, loop_requesting_ldB.b_addr_end, (loop_requesting_ldB.b_ex_spad_id) * (max_addr / concurrent_loops).U) ldB.io.req.bits.loop_id := loop_requesting_ldB_id - //ldB.io.req.bits.is_resadd := is_resadd + ldB.io.req.bits.is_scale := is_scale ldB.io.req.valid := !loop_requesting_ldB.ldb_started && loop_requesting_ldB.configured @@ -1036,7 +1059,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ //ex.io.req.bits.b_tranpose := loop_requesting_ex.b_transpose ex.io.req.bits.c_addr_start := ex_c_addr_start ex.io.req.bits.loop_id := loop_requesting_ex_id - //ex.io.req.bits.skip := is_resadd + ex.io.req.bits.skip := is_scale ex.io.req.valid := !loop_requesting_ex.ex_started && loop_requesting_ex.lda_started && loop_requesting_ex.ldb_started && loop_requesting_ex.ldd_started && loop_requesting_ex.configured @@ -1075,18 +1098,18 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ val loop_requesting_st_id = Mux(head_loop.st_started, tail_loop_id, head_loop_id) val loop_requesting_st = loops(loop_requesting_st_id) - stC.io.req.bits.max_k := loop_requesting_st.max_k //Mux(is_resadd, 1.U, loop_requesting_st.max_k) + stC.io.req.bits.max_k := Mux(is_scale, loop_requesting_st.max_i, loop_requesting_st.max_k) //Mux(is_resadd, 1.U, loop_requesting_st.max_k) //stC.io.req.bits.max_j := loop_requesting_st.max_j - stC.io.req.bits.max_i := loop_requesting_st.max_i + stC.io.req.bits.max_i := Mux(is_scale, loop_requesting_st.max_k, loop_requesting_st.max_i) //stC.io.req.bits.pad_j := loop_requesting_st.pad_j - stC.io.req.bits.pad_i := loop_requesting_st.pad_i + stC.io.req.bits.pad_i := loop_requesting_st.pad_i //Mux(is_scale, loop_requesting_st.pad_k, loop_requesting_st.pad_i) stC.io.req.bits.dram_addr := loop_requesting_st.c_dram_addr //stC.io.req.bits.dram_stride := loop_requesting_st.c_dram_stride stC.io.req.bits.full_c := loop_requesting_st.full_c stC.io.req.bits.act := loop_requesting_st.act stC.io.req.bits.addr_start := st_c_addr_start stC.io.req.bits.loop_id := loop_requesting_st_id - //stC.io.req.bits.is_resadd := is_resadd + stC.io.req.bits.is_scale := is_scale stC.io.req.valid := !loop_requesting_st.st_started && loop_requesting_st.ex_started && loop_requesting_st.configured @@ -1099,15 +1122,15 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ st_c_addr_start := floorAdd(st_c_addr_start, (max_acc_addr / concurrent_loops).U, max_acc_addr.U) } } -/* - when(is_resadd){ - ldA.io.req.bits.addr_start := loop_requesting_ldA.resadd_addr_start - ldB.io.req.bits.addr_end := loop_requesting_ldB.resadd_addr_start - stC.io.req.bits.addr_start := loop_requesting_st.resadd_addr_start + + when(is_scale){ + //ldA.io.req.bits.addr_start := loop_requesting_ldA.scale_addr_start + ldB.io.req.bits.addr_end := loop_requesting_ldB.scale_addr_start + stC.io.req.bits.addr_start := loop_requesting_st.scale_addr_start stC.io.req.valid := !loop_requesting_st.st_started && loop_requesting_st.configured } - */ + // Handle completed signals when (ldA.io.idle && loops(ldA.io.loop_id).running && loops(ldA.io.loop_id).lda_started) { loops(ldA.io.loop_id).lda_completed := true.B @@ -1140,7 +1163,7 @@ class VegaLoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_ l.reset() l.a_addr_start := (i * (max_addr / concurrent_loops)).U l.b_addr_end := ((i+1) * (max_addr / concurrent_loops)).U - l.resadd_addr_start := (i * (max_acc_addr / concurrent_loops)).U + l.scale_addr_start := (i * (max_acc_addr / concurrent_loops)).U } } }