From 79343989d517d2229783699fadafeca18956bb27 Mon Sep 17 00:00:00 2001
From: Seah Kim
Date: Wed, 25 Oct 2023 02:55:10 -0700
Subject: [PATCH] conv fix

---
 src/main/scala/gemmini/LoopConv.scala | 102 ++++++++++++++------------
 1 file changed, 57 insertions(+), 45 deletions(-)

diff --git a/src/main/scala/gemmini/LoopConv.scala b/src/main/scala/gemmini/LoopConv.scala
index a7ae0d6e..c0d579a9 100644
--- a/src/main/scala/gemmini/LoopConv.scala
+++ b/src/main/scala/gemmini/LoopConv.scala
@@ -16,8 +16,8 @@ class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_b
   val in_col_dim = UInt(small_iterator_bitwidth.W)
   val in_channels = UInt(large_iterator_bitwidth.W)
   val out_channels = UInt(large_iterator_bitwidth.W)
-  val out_col_dim = UInt(large_iterator_bitwidth.W)
-  val out_row_dim = UInt(large_iterator_bitwidth.W)
+  val out_col_dim = UInt(small_iterator_bitwidth.W)
+  val out_row_dim = UInt(small_iterator_bitwidth.W)
   val out_stride = UInt(large_iterator_bitwidth.W) //stride for output activation
   val in_stride = UInt(large_iterator_bitwidth.W) //stride for input activation
   val weight_stride = UInt(large_iterator_bitwidth.W) //stride for weight
@@ -77,7 +77,7 @@ class LoopConvLdBiasReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: I
   val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)
   val addr_start = UInt(log2Up(max_acc_addr).W)
   val dram_addr = UInt(coreMaxAddrBits.W)
-  val no_bias = Bool()
+  //val no_bias = Bool()
   val loop_id = UInt(log2Up(concurrent_loops).W)
 }

@@ -121,7 +121,7 @@ class LoopConvLdBias(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwi

   // Addresses
   val dram_offset = och * (acc_w/8).U
-  val dram_addr = Mux(req.no_bias, 0.U, req.dram_addr + LoopConv.castDramOffset(dram_offset))
+  val dram_addr = Mux(skip, 0.U, req.dram_addr + LoopConv.castDramOffset(dram_offset))
   //val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol
   val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * orows * ocols +& orow * ocols +& ocol

@@ -473,8 +473,8 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit

   // Iterators
   val och = Reg(UInt(large_iterator_bitwidth.W))
-  //val krow = Reg(UInt(tiny_iterator_bitwidth.W))
-  //val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
+  val krow = Reg(UInt(tiny_iterator_bitwidth.W))
+  val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
   val kch = Reg(UInt(large_iterator_bitwidth.W))

   // Addresses
@@ -485,7 +485,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
     req.trans_weight_0132 -> (((krow*kernel_dim*out_channels +& kcol*out_channels +& och) * in_channels +& kch) * (input_w/8).U)
   ))
   */
-  val dram_offset = Mux(req.dw, 0.U, ((kch * weight_stride +& och) * (input_w/8).U))
+  val dram_offset = Mux(req.dw, (krow * kernel_dim +& kcol) * (input_w/8).U, ((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * weight_stride +& och) * (input_w/8).U)

   val dram_addr = req.dram_addr + LoopConv.castDramOffset(dram_offset)

@@ -495,7 +495,7 @@
     addr_start + (kch / block_size.U(kch.getWidth.W)) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och,
     addr_start + (och / block_size.U(och.getWidth.W)) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch)
   */
-  val spad_addr = addr_start + (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs + kch
+  val spad_addr = addr_start + (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs + krow * kernel_dim * kchs + kcol * kchs + kch

   // Sizes
   /*
@@ -574,20 +574,22 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
     when (state === config) {
       state := ld
     }.otherwise {
-      val och_it = Mux(req.trans_weight_0132, block_size.U, max_chs_per_mvin)
-      val kch_it = Mux(req.trans_weight_0132, max_chs_per_mvin, block_size.U)
+      //val och_it = Mux(req.trans_weight_0132, block_size.U, max_chs_per_mvin)
+      //val kch_it = Mux(req.trans_weight_0132, max_chs_per_mvin, block_size.U)
+      val och_it = max_chs_per_mvin
+      val kch_it = block_size.U

       val next_kch = floorAdd(kch, kch_it, kchs)
-      //val next_kcol = floorAdd(kcol, 1.U, kcols, next_kch === 0.U)
-      //val next_krow = floorAdd(krow, 1.U, krows, next_kcol === 0.U && next_kch === 0.U)
-      val next_och = floorAdd(och, och_it, ochs, next_kch === 0.U)
+      val next_kcol = floorAdd(kcol, 1.U, kernel_dim, next_kch === 0.U)
+      val next_krow = floorAdd(krow, 1.U, kernel_dim, next_kcol === 0.U && next_kch === 0.U)
+      val next_och = floorAdd(och, och_it, ochs, next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U)

       kch := next_kch
-      //kcol := next_kcol
-      //krow := next_krow
+      kcol := next_kcol
+      krow := next_krow
       och := next_och

-      state := Mux(next_och === 0.U && next_kch === 0.U,
+      state := Mux(next_och === 0.U && next_kch === 0.U && next_krow === 0.U && next_kcol === 0.U,
         idle, ld)
     }
   }
@@ -597,6 +599,8 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
     req := io.req.bits
     state := config
     kch := 0.U
+    kcol := 0.U
+    krow := 0.U
     och := 0.U
   }
 }
@@ -614,6 +618,7 @@ class LoopConvExecuteReq(val large_iterator_bitwidth: Int, val small_iterator_bi
   val input_dilated = Bool()
   val trans_weight_0132 = Bool()
   val trans_input_3120 = Bool()
+  val accumulate = Bool()
   val loop_id = UInt(log2Up(concurrent_loops).W)
 }

@@ -660,24 +665,24 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera

   // Iterators
   val och = Reg(UInt(large_iterator_bitwidth.W))
-  //val krow = Reg(UInt(tiny_iterator_bitwidth.W))
-  //val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
+  val krow = Reg(UInt(tiny_iterator_bitwidth.W))
+  val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
   val kch = Reg(UInt(large_iterator_bitwidth.W))
   //val b = Reg(UInt(large_iterator_bitwidth.W))
   val orow = Reg(UInt(small_iterator_bitwidth.W))
   val ocol = Reg(UInt(small_iterator_bitwidth.W))

   // TODO kernel-dilation and input-dilation can never be activated at the same time, so we can optimize out some multiplications by kernel_dilation
-  //val skip_iteration = state >= pre && req.input_dilated && (((krow * kernel_dilation +& orow -& upad)(0) & req.input_dilated).asBool ||
-  //  ((kcol * kernel_dilation +& ocol -& lpad)(0) & req.input_dilated).asBool)
-  val skip_iteration = state >= pre && req.input_dilated && (((orow -& upad)(0) & req.input_dilated).asBool ||
-    ((ocol -& lpad)(0) & req.input_dilated).asBool)
+  val skip_iteration = state >= pre && req.input_dilated && (((krow * kernel_dilation +& orow -& upad)(0) & req.input_dilated).asBool ||
+    ((kcol * kernel_dilation +& ocol -& lpad)(0) & req.input_dilated).asBool)
+  //val skip_iteration = state >= pre && req.input_dilated && (((orow -& upad)(0) & req.input_dilated).asBool ||
+  //  ((ocol -& lpad)(0) & req.input_dilated).asBool)

-  //val pixels = Mux(kcols - kcol > req.max_pixels_per_row, req.max_pixels_per_row, kcols - kcol)
-  val pixels = Mux(kernel_dim > req.max_pixels_per_row, req.max_pixels_per_row, kernel_dim)
+  val pixels = Mux(kernel_dim - kcol > req.max_pixels_per_row, req.max_pixels_per_row, kernel_dim - kcol)
+  //val pixels = Mux(kernel_dim > req.max_pixels_per_row, req.max_pixels_per_row, kernel_dim)

-  val irow = undilated(orow * stride)// +& krow * kernel_dilation)
-  val icol = undilated(ocol * stride)// +& kcol * kernel_dilation)
+  val irow = undilated(orow * stride +& krow * kernel_dilation)
+  val icol = undilated(ocol * stride +& kcol * kernel_dilation)

   /*
   val I = Mux(req.trans_input_3120,
@@ -705,15 +710,15 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
   // val new_weights = b === 0.U && orow === 0.U && ocol === 0.U
   val new_weights = Reg(Bool())
   // these will be krow/kcol (which is 0)
-  //val krow_rot = Mux(req.wrot180, krows - krow - 1.U, krow)
-  //val kcol_rot = Mux(req.wrot180, kcols - kcol - 1.U, kcol)
+  val krow_rot = Mux(req.wrot180, kernel_dim - krow - 1.U, krow)
+  val kcol_rot = Mux(req.wrot180, kernel_dim - kcol - 1.U, kcol)

   /*
   val b_addr = Mux(req.trans_weight_0132,
     b_addr_start +& (kch / block_size.U(och.getWidth.W)) * krows * kcols * ochs +& krow_rot * kcols * ochs +& kcol_rot * ochs +& och,
     b_addr_start +& (och / block_size.U(och.getWidth.W)) * krows * kcols * kchs +& krow_rot * kcols * kchs +& kcol_rot * kchs +& kch)
   */
-  val b_addr = b_addr_start +& (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs +& kch
+  val b_addr = b_addr_start +& (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs +& krow_rot * kernel_dim * kchs +& kcol_rot * kchs +& kch

   class RoCCCommandWithAddr extends Bundle {
     val cmd = new RoCCCommand
@@ -724,6 +729,7 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
     val J = UInt()
     val K = UInt()
     val new_weights = Bool()
+    val k_start = Bool()
   }
   val command_p = Module(new Pipeline[RoCCCommandWithAddr](new RoCCCommandWithAddr, latency)())

@@ -773,10 +779,12 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
   command_p.io.in.bits.J := J
   command_p.io.in.bits.K := K
   command_p.io.in.bits.new_weights := new_weights
+  command_p.io.in.bits.k_start := kch === 0.U && kcol === 0.U && krow === 0.U

   command_p.io.out.ready := io.cmd.ready && !io.rob_overloaded
   io.cmd.valid := command_p.io.out.valid && !io.rob_overloaded
   io.cmd.bits := command_p.io.out.bits.cmd
+  val k_start = command_p.io.out.bits.k_start

   when (command_p.io.out.bits.cmd.inst.funct === PRELOAD_CMD) {
     val o = command_p.io.out.bits
@@ -791,7 +799,8 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
     pre_cmd_rs2 := DontCare
     pre_cmd_rs2.num_rows := o.I.asUInt
     pre_cmd_rs2.num_cols := o.J.asUInt
-    pre_cmd_rs2.local_addr := cast_to_acc_addr(pre_cmd_rs2.local_addr, o.c_addr, accumulate = true.B, read_full = false.B)
+    pre_cmd_rs2.local_addr := cast_to_acc_addr(pre_cmd_rs2.local_addr, o.c_addr, accumulate = req.accumulate || !k_start, read_full = false.B)
+    //pre_cmd_rs2.local_addr := cast_to_acc_addr(pre_cmd_rs2.local_addr, o.c_addr, accumulate = true.B, read_full = false.B)

     io.cmd.bits.rs1 := pre_cmd_rs1.asUInt
     io.cmd.bits.rs2 := pre_cmd_rs2.asUInt
@@ -826,32 +835,34 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
       state := comp
     }.otherwise {
       //val b_it = Mux(req.trans_input_3120, block_size.U, 1.U)
-      val ocol_it = Mux(skip_iteration || req.trans_input_3120, 1.U, block_size.U << req.input_dilated).asUInt
+      //val ocol_it = Mux(skip_iteration || req.trans_input_3120, 1.U, block_size.U << req.input_dilated).asUInt
+      val ocol_it = Mux(skip_iteration, 1.U, block_size.U << req.input_dilated).asUInt

       val next_ocol = floorAdd(ocol, ocol_it, ocols)
       val next_orow = floorAdd(orow, 1.U, orows, next_ocol === 0.U)
       //val next_b = floorAdd(b, b_it, batches, next_orow === 0.U && next_ocol === 0.U)
       val next_kch = floorAdd(kch, block_size.U, kchs, next_orow === 0.U && next_ocol === 0.U)
-      //val next_kcol = floorAdd(kcol, req.max_pixels_per_row, kcols,
-      //  next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U)
-      //val next_krow = floorAdd(krow, 1.U, krows,
-      //  next_kcol === 0.U && next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U)
-      val next_och = floorAdd(och, block_size.U, ochs, next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)
+      val next_kcol = floorAdd(kcol, req.max_pixels_per_row, kernel_dim,
+        next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)
+      val next_krow = floorAdd(krow, 1.U, kernel_dim,
+        next_kcol === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)
+      val next_och = floorAdd(och, block_size.U, ochs, next_krow === 0.U &&
+        next_kcol === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)

       ocol := next_ocol
       orow := next_orow
       //b := next_b
       kch := next_kch
-      //kcol := next_kcol
-      //krow := next_krow
+      kcol := next_kcol
+      krow := next_krow
       och := next_och

       when (next_orow === 0.U && next_ocol === 0.U) {
         new_weights := true.B
       }

-      state := Mux(next_och === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U, idle, pre)
+      state := Mux(next_och === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U && next_krow === 0.U && next_kcol === 0.U, idle, pre)
     }
   }

@@ -865,8 +876,8 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
     orow := 0.U
     ocol := 0.U
     och := 0.U
-    //krow := 0.U
-    //kcol := 0.U
+    krow := 0.U
+    kcol := 0.U
     kch := 0.U

     new_weights := true.B
@@ -1108,7 +1119,7 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s
   val input_dram_addr = UInt(coreMaxAddrBits.W)
   val output_dram_addr = UInt(coreMaxAddrBits.W)

-  val no_bias = Bool()
+  val accumulate = Bool()
   val wrot180 = Bool()
   val no_pool = Bool()
   val downsample = Bool()
@@ -1222,7 +1233,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
                 has_dw_convs: Boolean)
   (implicit p: Parameters) extends Module {
   val large_iterator_bitwidth = 16
-  val small_iterator_bitwidth = 12
+  val small_iterator_bitwidth = 10
   val tiny_iterator_bitwidth = 4

   val max_block_len = (dma_max_bytes / (block_size * (input_w / 8))) max 1
@@ -1391,7 +1402,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
         }

         is (LOOP_CONV_WS) {
-          loop_being_configured.no_bias := cmd.bits.cmd.rs1(0)
+          loop_being_configured.accumulate := cmd.bits.cmd.rs1(0)

           // TODO we added a default value for max_pixels_per_row just to maintain backwards compatibility. we should deprecate and remove it later
           val config_max_pixels_per_row = cmd.bits.cmd.rs1(15, 8)
@@ -1435,7 +1446,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
   ld_bias.io.req.bits.derived_params := loop_requesting_ld_bias.derived_params()
   ld_bias.io.req.bits.addr_start := ld_bias_addr_start
   ld_bias.io.req.bits.dram_addr := loop_requesting_ld_bias.bias_dram_addr
-  ld_bias.io.req.bits.no_bias := loop_requesting_ld_bias.no_bias
+  //ld_bias.io.req.bits.no_bias := loop_requesting_ld_bias.no_bias
   ld_bias.io.req.bits.loop_id := loop_requesting_ld_bias_id

   ld_bias.io.req.valid := !loop_requesting_ld_bias.ld_bias_started && loop_requesting_ld_bias.configured
@@ -1504,6 +1515,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
   ex.io.req.bits.trans_weight_0132 := loop_requesting_ex.trans_weight_0132
   ex.io.req.bits.trans_input_3120 := loop_requesting_ex.trans_input_3120
   ex.io.req.bits.loop_id := loop_requesting_ex_id
+  ex.io.req.bits.accumulate := loop_requesting_ex.accumulate

   ex.io.req.valid := !loop_requesting_ex.ex_started && loop_requesting_ex.ld_bias_started &&
     loop_requesting_ex.ld_input_started && loop_requesting_ex.ld_weights_started && loop_requesting_ex.configured
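
For reference, the reworked LoopConvLdWeight addressing walks the weight tensor by kernel row, kernel column, input channel, and output channel instead of collapsing the kernel's spatial dimensions. The plain-Scala sketch below mirrors the two new address expressions so the layout can be checked on paper; it is an illustration only (the real logic is Chisel hardware generated by this file), and blockSize, kernelDim, inChannels, weightStride, and the element width are made-up example values, not taken from the patch.

// Plain-Scala model of the LdWeight address arithmetic touched by this patch (illustration only).
object WeightAddrModel {
  // Hypothetical example parameters (not from the patch).
  val blockSize    = 16  // systolic-array tile width
  val kernelDim    = 3   // kernel is kernel_dim x kernel_dim
  val inChannels   = 64
  val weightStride = 64  // elements between consecutive (krow, kcol, kch) weight rows in DRAM
  val elemBytes    = 1   // input_w / 8

  // DRAM byte offset of weight element (krow, kcol, kch, och) in the non-depthwise case, mirroring
  // ((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * weight_stride +& och) * (input_w/8).U
  def dramOffset(krow: Int, kcol: Int, kch: Int, och: Int): Int =
    ((krow * kernelDim * inChannels + kcol * inChannels + kch) * weightStride + och) * elemBytes

  // Scratchpad row of the same element, mirroring the new spad_addr expression:
  // addr_start + (och / block_size) * kernel_dim * kernel_dim * kchs + krow * kernel_dim * kchs + kcol * kchs + kch
  def spadAddr(addrStart: Int, kchs: Int, krow: Int, kcol: Int, kch: Int, och: Int): Int =
    addrStart + (och / blockSize) * kernelDim * kernelDim * kchs + krow * kernelDim * kchs + kcol * kchs + kch

  def main(args: Array[String]): Unit = {
    // Walk the kernel positions for kch = och = 0 to see how DRAM and scratchpad rows advance.
    for (krow <- 0 until kernelDim; kcol <- 0 until kernelDim)
      println(s"krow=$krow kcol=$kcol -> dram byte ${dramOffset(krow, kcol, 0, 0)}, spad row ${spadAddr(0, inChannels, krow, kcol, 0, 0)}")
  }
}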
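
The state-machine update in LoopConvLdWeight (and the matching one in LoopConvExecute) re-enables the krow/kcol iterators and chains the counters so that each level only advances when every inner level has just wrapped back to zero. A small software model of that ripple-carry style update is sketched below, with a simplified floorAdd and hypothetical bounds and step sizes (kchStep standing in for kch_it, ochStep for och_it).

// Software sketch of the restored iterator nest: kch innermost, then kcol, krow, och (illustration only).
object LoopNestSketch {
  // Simplified software stand-in for a wrap-around increment gated by an enable.
  def floorAdd(x: Int, step: Int, max: Int, en: Boolean = true): Int =
    if (!en) x else if (x + step >= max) 0 else x + step

  def main(args: Array[String]): Unit = {
    // Hypothetical example bounds and step sizes, not taken from the patch.
    val (kchs, kernelDim, ochs, kchStep, ochStep) = (32, 3, 64, 16, 16)
    var (kch, kcol, krow, och) = (0, 0, 0, 0)
    var steps = 0
    var done = false
    while (!done) {
      steps += 1
      // Mirror the update order in the patch: a level advances only when all inner levels just wrapped.
      val nextKch  = floorAdd(kch, kchStep, kchs)
      val nextKcol = floorAdd(kcol, 1, kernelDim, nextKch == 0)
      val nextKrow = floorAdd(krow, 1, kernelDim, nextKcol == 0 && nextKch == 0)
      val nextOch  = floorAdd(och, ochStep, ochs, nextKrow == 0 && nextKcol == 0 && nextKch == 0)
      kch = nextKch; kcol = nextKcol; krow = nextKrow; och = nextOch
      done = nextOch == 0 && nextKrow == 0 && nextKcol == 0 && nextKch == 0
    }
    // Expect (kchs/kchStep) * kernelDim * kernelDim * (ochs/ochStep) = 2 * 9 * 4 = 72 steps.
    println(s"issued $steps mvin steps")
  }
}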
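
The preload path now derives its accumulate bit from the new k_start signal: every kernel-position/input-channel step after the first accumulates its partial sums on top of the same output rows, and the LOOP_CONV_WS flag carried in rs1 bit 0 (previously no_bias, now accumulate) can force accumulation even on the first step. A minimal restatement of that predicate in plain Scala, for illustration only:

object AccumulateRule {
  // Mirrors accumulate = req.accumulate || !k_start, where k_start was captured
  // upstream as (kch === 0.U && kcol === 0.U && krow === 0.U).
  def accumulateBit(reqAccumulate: Boolean, krow: Int, kcol: Int, kch: Int): Boolean = {
    val kStart = krow == 0 && kcol == 0 && kch == 0
    reqAccumulate || !kStart
  }
}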