Skip to content

Commit

Permalink
conv fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Seah Kim authored and Seah Kim committed Oct 25, 2023
1 parent 69456b0 commit 7934398
Showing 1 changed file with 57 additions and 45 deletions.
102 changes: 57 additions & 45 deletions src/main/scala/gemmini/LoopConv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class LoopConvOuterBounds(val large_iterator_bitwidth: Int, val small_iterator_b
val in_col_dim = UInt(small_iterator_bitwidth.W)
val in_channels = UInt(large_iterator_bitwidth.W)
val out_channels = UInt(large_iterator_bitwidth.W)
val out_col_dim = UInt(large_iterator_bitwidth.W)
val out_row_dim = UInt(large_iterator_bitwidth.W)
val out_col_dim = UInt(small_iterator_bitwidth.W)
val out_row_dim = UInt(small_iterator_bitwidth.W)
val out_stride = UInt(large_iterator_bitwidth.W) //stride for output activation
val in_stride = UInt(large_iterator_bitwidth.W) //stride for input activation
val weight_stride = UInt(large_iterator_bitwidth.W) //stride for weight
Expand Down Expand Up @@ -77,7 +77,7 @@ class LoopConvLdBiasReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth: I
val derived_params = new LoopConvDerivedParams(large_iterator_bitwidth, small_iterator_bitwidth, tiny_iterator_bitwidth)
val addr_start = UInt(log2Up(max_acc_addr).W)
val dram_addr = UInt(coreMaxAddrBits.W)
val no_bias = Bool()
//val no_bias = Bool()
val loop_id = UInt(log2Up(concurrent_loops).W)
}

Expand Down Expand Up @@ -121,7 +121,7 @@ class LoopConvLdBias(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitwi

// Addresses
val dram_offset = och * (acc_w/8).U
val dram_addr = Mux(req.no_bias, 0.U, req.dram_addr + LoopConv.castDramOffset(dram_offset))
val dram_addr = Mux(skip, 0.U, req.dram_addr + LoopConv.castDramOffset(dram_offset))
//val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * batches * orows * ocols +& b * orows * ocols +& orow * ocols +& ocol
val spad_addr = acc_addr_start +& (och / block_size.U(och.getWidth.W)) * orows * ocols +& orow * ocols +& ocol

Expand Down Expand Up @@ -473,8 +473,8 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit

// Iterators
val och = Reg(UInt(large_iterator_bitwidth.W))
//val krow = Reg(UInt(tiny_iterator_bitwidth.W))
//val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
val krow = Reg(UInt(tiny_iterator_bitwidth.W))
val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
val kch = Reg(UInt(large_iterator_bitwidth.W))

// Addresses
Expand All @@ -485,7 +485,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
req.trans_weight_0132 -> (((krow*kernel_dim*out_channels +& kcol*out_channels +& och) * in_channels +& kch) * (input_w/8).U)
))
*/
val dram_offset = Mux(req.dw, 0.U, ((kch * weight_stride +& och) * (input_w/8).U))
val dram_offset = Mux(req.dw, (krow * kernel_dim +& kcol) * (input_w/8).U, ((krow*kernel_dim*in_channels +& kcol*in_channels +& kch) * weight_stride +& och) * (input_w/8).U)

val dram_addr = req.dram_addr + LoopConv.castDramOffset(dram_offset)

Expand All @@ -495,7 +495,7 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
addr_start + (kch / block_size.U(kch.getWidth.W)) * krows * kcols * ochs + krow * kcols * ochs + kcol * ochs + och,
addr_start + (och / block_size.U(och.getWidth.W)) * krows * kcols * kchs + krow * kcols * kchs + kcol * kchs + kch)
*/
val spad_addr = addr_start + (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs + kch
val spad_addr = addr_start + (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs + krow * kernel_dim * kchs + kcol * kchs + kch

// Sizes
/*
Expand Down Expand Up @@ -574,20 +574,22 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
when (state === config) {
state := ld
}.otherwise {
val och_it = Mux(req.trans_weight_0132, block_size.U, max_chs_per_mvin)
val kch_it = Mux(req.trans_weight_0132, max_chs_per_mvin, block_size.U)
//val och_it = Mux(req.trans_weight_0132, block_size.U, max_chs_per_mvin)
//val kch_it = Mux(req.trans_weight_0132, max_chs_per_mvin, block_size.U)
val och_it = max_chs_per_mvin
val kch_it = block_size.U

val next_kch = floorAdd(kch, kch_it, kchs)
//val next_kcol = floorAdd(kcol, 1.U, kcols, next_kch === 0.U)
//val next_krow = floorAdd(krow, 1.U, krows, next_kcol === 0.U && next_kch === 0.U)
val next_och = floorAdd(och, och_it, ochs, next_kch === 0.U)
val next_kcol = floorAdd(kcol, 1.U, kernel_dim, next_kch === 0.U)
val next_krow = floorAdd(krow, 1.U, kernel_dim, next_kcol === 0.U && next_kch === 0.U)
val next_och = floorAdd(och, och_it, ochs, next_krow === 0.U && next_kcol === 0.U && next_kch === 0.U)

kch := next_kch
//kcol := next_kcol
//krow := next_krow
kcol := next_kcol
krow := next_krow
och := next_och

state := Mux(next_och === 0.U && next_kch === 0.U,
state := Mux(next_och === 0.U && next_kch === 0.U && next_krow === 0.U && next_kcol === 0.U,
idle, ld)
}
}
Expand All @@ -597,6 +599,8 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
req := io.req.bits
state := config
kch := 0.U
kcol := 0.U
krow := 0.U
och := 0.U
}
}
Expand All @@ -614,6 +618,7 @@ class LoopConvExecuteReq(val large_iterator_bitwidth: Int, val small_iterator_bi
val input_dilated = Bool()
val trans_weight_0132 = Bool()
val trans_input_3120 = Bool()
val accumulate = Bool()
val loop_id = UInt(log2Up(concurrent_loops).W)
}

Expand Down Expand Up @@ -660,24 +665,24 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera

// Iterators
val och = Reg(UInt(large_iterator_bitwidth.W))
//val krow = Reg(UInt(tiny_iterator_bitwidth.W))
//val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
val krow = Reg(UInt(tiny_iterator_bitwidth.W))
val kcol = Reg(UInt(tiny_iterator_bitwidth.W))
val kch = Reg(UInt(large_iterator_bitwidth.W))
//val b = Reg(UInt(large_iterator_bitwidth.W))
val orow = Reg(UInt(small_iterator_bitwidth.W))
val ocol = Reg(UInt(small_iterator_bitwidth.W))

// TODO kernel-dilation and input-dilation can never be activated at the same time, so we can optimize out some multiplications by kernel_dilation
//val skip_iteration = state >= pre && req.input_dilated && (((krow * kernel_dilation +& orow -& upad)(0) & req.input_dilated).asBool ||
// ((kcol * kernel_dilation +& ocol -& lpad)(0) & req.input_dilated).asBool)
val skip_iteration = state >= pre && req.input_dilated && (((orow -& upad)(0) & req.input_dilated).asBool ||
((ocol -& lpad)(0) & req.input_dilated).asBool)
val skip_iteration = state >= pre && req.input_dilated && (((krow * kernel_dilation +& orow -& upad)(0) & req.input_dilated).asBool ||
((kcol * kernel_dilation +& ocol -& lpad)(0) & req.input_dilated).asBool)
//val skip_iteration = state >= pre && req.input_dilated && (((orow -& upad)(0) & req.input_dilated).asBool ||
// ((ocol -& lpad)(0) & req.input_dilated).asBool)

//val pixels = Mux(kcols - kcol > req.max_pixels_per_row, req.max_pixels_per_row, kcols - kcol)
val pixels = Mux(kernel_dim > req.max_pixels_per_row, req.max_pixels_per_row, kernel_dim)
val pixels = Mux(kernel_dim - kcol > req.max_pixels_per_row, req.max_pixels_per_row, kernel_dim - kcol)
//val pixels = Mux(kernel_dim > req.max_pixels_per_row, req.max_pixels_per_row, kernel_dim)

val irow = undilated(orow * stride)// +& krow * kernel_dilation)
val icol = undilated(ocol * stride)// +& kcol * kernel_dilation)
val irow = undilated(orow * stride +& krow * kernel_dilation)
val icol = undilated(ocol * stride +& kcol * kernel_dilation)

/*
val I = Mux(req.trans_input_3120,
Expand Down Expand Up @@ -705,15 +710,15 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
// val new_weights = b === 0.U && orow === 0.U && ocol === 0.U
val new_weights = Reg(Bool())
// these will be krow/kcol (which is 0)
//val krow_rot = Mux(req.wrot180, krows - krow - 1.U, krow)
//val kcol_rot = Mux(req.wrot180, kcols - kcol - 1.U, kcol)
val krow_rot = Mux(req.wrot180, kernel_dim - krow - 1.U, krow)
val kcol_rot = Mux(req.wrot180, kernel_dim - kcol - 1.U, kcol)

/*
val b_addr = Mux(req.trans_weight_0132,
b_addr_start +& (kch / block_size.U(och.getWidth.W)) * krows * kcols * ochs +& krow_rot * kcols * ochs +& kcol_rot * ochs +& och,
b_addr_start +& (och / block_size.U(och.getWidth.W)) * krows * kcols * kchs +& krow_rot * kcols * kchs +& kcol_rot * kchs +& kch)
*/
val b_addr = b_addr_start +& (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs +& kch
val b_addr = b_addr_start +& (och / block_size.U(och.getWidth.W)) * kernel_dim * kernel_dim * kchs +& krow_rot * kernel_dim * kchs +& kcol_rot * kchs +& kch

class RoCCCommandWithAddr extends Bundle {
val cmd = new RoCCCommand
Expand All @@ -724,6 +729,7 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
val J = UInt()
val K = UInt()
val new_weights = Bool()
val k_start = Bool()
}
val command_p = Module(new Pipeline[RoCCCommandWithAddr](new RoCCCommandWithAddr, latency)())

Expand Down Expand Up @@ -773,10 +779,12 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
command_p.io.in.bits.J := J
command_p.io.in.bits.K := K
command_p.io.in.bits.new_weights := new_weights
command_p.io.in.bits.k_start := kch === 0.U && kcol === 0.U && krow === 0.U

command_p.io.out.ready := io.cmd.ready && !io.rob_overloaded
io.cmd.valid := command_p.io.out.valid && !io.rob_overloaded
io.cmd.bits := command_p.io.out.bits.cmd
val k_start = command_p.io.out.bits.k_start
when (command_p.io.out.bits.cmd.inst.funct === PRELOAD_CMD) {
val o = command_p.io.out.bits

Expand All @@ -791,7 +799,8 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
pre_cmd_rs2 := DontCare
pre_cmd_rs2.num_rows := o.I.asUInt
pre_cmd_rs2.num_cols := o.J.asUInt
pre_cmd_rs2.local_addr := cast_to_acc_addr(pre_cmd_rs2.local_addr, o.c_addr, accumulate = true.B, read_full = false.B)
pre_cmd_rs2.local_addr := cast_to_acc_addr(pre_cmd_rs2.local_addr, o.c_addr, accumulate = req.accumulate || !k_start, read_full = false.B)
//pre_cmd_rs2.local_addr := cast_to_acc_addr(pre_cmd_rs2.local_addr, o.c_addr, accumulate = true.B, read_full = false.B)

io.cmd.bits.rs1 := pre_cmd_rs1.asUInt
io.cmd.bits.rs2 := pre_cmd_rs2.asUInt
Expand Down Expand Up @@ -826,32 +835,34 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
state := comp
}.otherwise {
//val b_it = Mux(req.trans_input_3120, block_size.U, 1.U)
val ocol_it = Mux(skip_iteration || req.trans_input_3120, 1.U, block_size.U << req.input_dilated).asUInt
//val ocol_it = Mux(skip_iteration || req.trans_input_3120, 1.U, block_size.U << req.input_dilated).asUInt
val ocol_it = Mux(skip_iteration, 1.U, block_size.U << req.input_dilated).asUInt

val next_ocol = floorAdd(ocol, ocol_it, ocols)
val next_orow = floorAdd(orow, 1.U, orows, next_ocol === 0.U)
//val next_b = floorAdd(b, b_it, batches, next_orow === 0.U && next_ocol === 0.U)
val next_kch = floorAdd(kch, block_size.U, kchs,
next_orow === 0.U && next_ocol === 0.U)
//val next_kcol = floorAdd(kcol, req.max_pixels_per_row, kcols,
// next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U)
//val next_krow = floorAdd(krow, 1.U, krows,
// next_kcol === 0.U && next_kch === 0.U && next_b === 0.U && next_orow === 0.U && next_ocol === 0.U)
val next_och = floorAdd(och, block_size.U, ochs, next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)
val next_kcol = floorAdd(kcol, req.max_pixels_per_row, kernel_dim,
next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)
val next_krow = floorAdd(krow, 1.U, kernel_dim,
next_kcol === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)
val next_och = floorAdd(och, block_size.U, ochs, next_krow === 0.U &&
next_kcol === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U)

ocol := next_ocol
orow := next_orow
//b := next_b
kch := next_kch
//kcol := next_kcol
//krow := next_krow
kcol := next_kcol
krow := next_krow
och := next_och

when (next_orow === 0.U && next_ocol === 0.U) {
new_weights := true.B
}

state := Mux(next_och === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U, idle, pre)
state := Mux(next_och === 0.U && next_kch === 0.U && next_orow === 0.U && next_ocol === 0.U && next_krow === 0.U && next_kcol === 0.U, idle, pre)
}
}

Expand All @@ -865,8 +876,8 @@ class LoopConvExecute(block_size: Int, large_iterator_bitwidth: Int, small_itera
orow := 0.U
ocol := 0.U
och := 0.U
//krow := 0.U
//kcol := 0.U
krow := 0.U
kcol := 0.U
kch := 0.U

new_weights := true.B
Expand Down Expand Up @@ -1108,7 +1119,7 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s
val input_dram_addr = UInt(coreMaxAddrBits.W)
val output_dram_addr = UInt(coreMaxAddrBits.W)

val no_bias = Bool()
val accumulate = Bool()
val wrot180 = Bool()
val no_pool = Bool()
val downsample = Bool()
Expand Down Expand Up @@ -1222,7 +1233,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
has_dw_convs: Boolean)
(implicit p: Parameters) extends Module {
val large_iterator_bitwidth = 16
val small_iterator_bitwidth = 12
val small_iterator_bitwidth = 10
val tiny_iterator_bitwidth = 4

val max_block_len = (dma_max_bytes / (block_size * (input_w / 8))) max 1
Expand Down Expand Up @@ -1391,7 +1402,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
}

is (LOOP_CONV_WS) {
loop_being_configured.no_bias := cmd.bits.cmd.rs1(0)
loop_being_configured.accumulate := cmd.bits.cmd.rs1(0)

// TODO we added a default value for max_pixels_per_row just to maintain backwards compatibility. we should deprecate and remove it later
val config_max_pixels_per_row = cmd.bits.cmd.rs1(15, 8)
Expand Down Expand Up @@ -1435,7 +1446,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
ld_bias.io.req.bits.derived_params := loop_requesting_ld_bias.derived_params()
ld_bias.io.req.bits.addr_start := ld_bias_addr_start
ld_bias.io.req.bits.dram_addr := loop_requesting_ld_bias.bias_dram_addr
ld_bias.io.req.bits.no_bias := loop_requesting_ld_bias.no_bias
//ld_bias.io.req.bits.no_bias := loop_requesting_ld_bias.no_bias
ld_bias.io.req.bits.loop_id := loop_requesting_ld_bias_id

ld_bias.io.req.valid := !loop_requesting_ld_bias.ld_bias_started && loop_requesting_ld_bias.configured
Expand Down Expand Up @@ -1504,6 +1515,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
ex.io.req.bits.trans_weight_0132 := loop_requesting_ex.trans_weight_0132
ex.io.req.bits.trans_input_3120 := loop_requesting_ex.trans_input_3120
ex.io.req.bits.loop_id := loop_requesting_ex_id
ex.io.req.bits.accumulate := loop_requesting_ex.accumulate

ex.io.req.valid := !loop_requesting_ex.ex_started && loop_requesting_ex.ld_bias_started &&
loop_requesting_ex.ld_input_started && loop_requesting_ex.ld_weights_started && loop_requesting_ex.configured
Expand Down

0 comments on commit 7934398

Please sign in to comment.