Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for non-zero padding in instructions mvin and loop_conv_ws #274

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/main/scala/gemmini/GemminiConfigs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ case class GemminiArrayConfig[T <: Data : Arithmetic, U <: Data, V <: Data](

val pixel_repeats_bits = 8 min log2Up(meshColumns * tileColumns + 1)

// Note: the width is capped at 16 because only 16 bits are left in the config mvin instruction
val padding_value_bits = inputType.getWidth min 16

val hasIm2Col = false

val tree_reduction = use_tree_reduction_if_possible && dataflow == Dataflow.WS && tileRows > 1
Expand Down
7 changes: 6 additions & 1 deletion src/main/scala/gemmini/LoadController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
val control_state = RegInit(waiting_for_command)

val strides = Reg(Vec(load_states, UInt(coreMaxAddrBits.W)))
val padding_values = Reg(Vec(load_states, UInt(padding_value_bits.W)))
val scales = Reg(Vec(load_states, UInt(mvin_scale_t_bits.W)))
val shrinks = Reg(Vec(load_states, Bool())) // Shrink inputs to accumulator
val block_strides = Reg(Vec(load_states, UInt(block_stride_bits.W))) // Spad stride during block move-ins
Expand All @@ -47,10 +48,11 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
val cols = mvin_rs2.num_cols
val rows = mvin_rs2.num_rows

val config_stride = cmd.bits.cmd.rs2
val config_stride = cmd.bits.cmd.rs2(31,0)

val config_mvin_rs1 = cmd.bits.cmd.rs1.asTypeOf(new ConfigMvinRs1(mvin_scale_t_bits, block_stride_bits, pixel_repeats_bits))

val config_padding_value = cmd.bits.cmd.rs2(63,32)
val config_scale = config_mvin_rs1.scale
val config_shrink = config_mvin_rs1.shrink
val config_block_stride = config_mvin_rs1.stride
Expand All @@ -63,6 +65,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
val config_state_id = config_mvin_rs1.state_id
val state_id = Mux(cmd.bits.cmd.inst.funct === CONFIG_CMD, config_state_id, load_state_id)

val padding_value = padding_values(state_id)
val stride = strides(state_id)
val scale = scales(state_id)
val shrink = shrinks(state_id)
Expand Down Expand Up @@ -109,6 +112,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
io.dma.req.bits.all_zeros := all_zeros
io.dma.req.bits.status := mstatus
io.dma.req.bits.pixel_repeats := pixel_repeat
io.dma.req.bits.padding_value := padding_value

// Command tracker IO
cmd_tracker.io.alloc.valid := control_state === waiting_for_command && cmd.valid && DoLoad
Expand Down Expand Up @@ -141,6 +145,7 @@ class LoadController[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig
is (waiting_for_command) {
when (cmd.valid) {
when(DoConfig) {
padding_value := config_padding_value
stride := config_stride
scale := config_scale
shrink := config_shrink
Expand Down
14 changes: 12 additions & 2 deletions src/main/scala/gemmini/LoopConv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class LoopConvLdInputReq(val coreMaxAddrBits: Int, val large_iterator_bitwidth:
val dram_addr = UInt(coreMaxAddrBits.W)
val downsample = Bool()
val max_pixels_per_row = UInt(small_iterator_bitwidth.W)
val padding_value = UInt(8.W) // TODO magic number
val input_dilated = Bool()
val trans_input_3120 = Bool()
val loop_id = UInt(log2Up(concurrent_loops).W)
Expand Down Expand Up @@ -320,7 +321,10 @@ class LoopConvLdInput(block_size: Int, coreMaxAddrBits: Int, large_iterator_bitw
config_cmd_rs1._unused := 1.U
config_cmd.rs1 := config_cmd_rs1.asUInt()

config_cmd.rs2 := dram_stride << req.downsample
val config_rs2 = Wire(Vec(2,UInt(32.W)))
config_rs2(0) := dram_stride << req.downsample
config_rs2(1) := req.padding_value
config_cmd.rs2 := config_rs2.asUInt

val mvin_cmd = Wire(new RoCCCommand)
mvin_cmd := DontCare
Expand Down Expand Up @@ -499,7 +503,10 @@ class LoopConvLdWeight(block_size: Int, coreMaxAddrBits: Int, large_iterator_bit
config_cmd_rs1._unused := 1.U
config_cmd.rs1 := config_cmd_rs1.asUInt

config_cmd.rs2 := dram_stride
val config_rs2 = Wire(Vec(2,UInt(32.W)))
config_rs2(0) := dram_stride
config_rs2(1) := 0.U
config_cmd.rs2 := config_rs2.asUInt

val mvin_cmd = Wire(new RoCCCommand)
mvin_cmd := DontCare
Expand Down Expand Up @@ -1070,6 +1077,7 @@ class LoopConvState(val block_size: Int, val large_iterator_bitwidth: Int, val s
val dw = Bool()

val max_pixels_per_row = UInt(small_iterator_bitwidth.W)
val padding_value = UInt(8.W) // TODO magic number

val configured = Bool()

Expand Down Expand Up @@ -1327,6 +1335,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:

is (LOOP_CONV_WS) {
loop_being_configured.no_bias := cmd.bits.cmd.rs1(0)
// The padding value is read from bits 55:48 of rs1 of the LOOP_CONV_WS command.
// rs1 is accessed through the nested `cmd` field, like the other rs1 reads in this case.
loop_being_configured.padding_value := cmd.bits.cmd.rs1(55, 48)

// TODO: we added a default value for max_pixels_per_row just to maintain backwards compatibility. We should deprecate and remove it later.
val config_max_pixels_per_row = cmd.bits.cmd.rs1(15, 8)
Expand Down Expand Up @@ -1394,6 +1403,7 @@ class LoopConv (block_size: Int, coreMaxAddrBits: Int, reservation_station_size:
ld_input.io.req.bits.input_dilated := loop_requesting_ld_input.input_dilated
ld_input.io.req.bits.trans_input_3120 := loop_requesting_ld_input.trans_input_3120
ld_input.io.req.bits.loop_id := loop_requesting_ld_input_id
ld_input.io.req.bits.padding_value := loop_requesting_ld_input.padding_value

ld_input.io.req.valid := !loop_requesting_ld_input.ld_input_started && loop_requesting_ld_input.configured

Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/gemmini/Scratchpad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ScratchpadMemReadRequest[U <: Data](local_addr_t: LocalAddr, scale_t_bits:
val pixel_repeats = UInt(8.W) // TODO magic numbers
val cmd_id = UInt(8.W) // TODO don't use a magic number here
val status = new MStatus

val padding_value = UInt(8.W) // TODO magic numbers
}

class ScratchpadMemWriteRequest(local_addr_t: LocalAddr, acc_t_bits: Int, scale_t_bits: Int)
Expand Down Expand Up @@ -543,7 +543,8 @@ class Scratchpad[T <: Data, U <: Data, V <: Data](config: GemminiArrayConfig[T,
mvin_scale_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals
}.elsewhen (zerowrite) {
bio.write.addr := zero_writer_pixel_repeater.io.resp.bits.laddr.sp_row()
bio.write.data := 0.U
val paddingdata = VecInit(Seq.fill(spad_w/8)(zero_writer_pixel_repeater.io.req.bits.tag.padding_value))
bio.write.data := paddingdata.asUInt
bio.write.mask := zero_writer_pixel_repeater.io.resp.bits.mask

zero_writer_pixel_repeater.io.resp.ready := true.B // TODO we combinationally couple valid and ready signals
Expand Down