Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
SeahK committed Feb 7, 2024
1 parent 8c8b38b commit 25710ae
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 13 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@
[submodule "software/libgemmini"]
path = software/libgemmini
url = https://github.com/ucb-bar/libgemmini.git
[submodule "software/libdma"]
path = software/libdma
url = https://github.com/ucb-bar/libdma
1 change: 1 addition & 0 deletions software/libdma
Submodule libdma added at 3e0712
2 changes: 1 addition & 1 deletion src/main/scala/gemmini/AccumulatorScale.scala
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ object AccumulatorScale {
val neg_q_iexp = neg(q)
val z_iexp = (neg_q_iexp * qln2_inv).asUInt.do_>>(16).asTypeOf(q) // q is non-positive
val z_iexp_saturated = Wire(z_iexp.cloneType)
z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S, z_iexp)
z_iexp_saturated := Mux((5 until 16).map(z_iexp.asUInt(_)).reduce(_ | _), 32.S.asTypeOf(z_iexp), z_iexp)
val qp_iexp = q.mac(z_iexp, qln2).withWidthOf(q)
val q_poly_iexp = qc.mac(qp_iexp + qb, qp_iexp + qb).withWidthOf(q)
// we don't want a rounding shift
Expand Down
39 changes: 38 additions & 1 deletion src/main/scala/gemmini/ConfigsFP.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,33 @@ object GemminiFPConfigs {
mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
mvin_scale_acc_args = Some(ScaleArguments((t: Float, u: Float) => t * u, 4, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")),
)


// FP32 configuration variant named "slam": a copy of FP32DefaultConfig with
// the fields below overridden. Fields not listed keep their FP32DefaultConfig
// values.
val slamFPConfig = FP32DefaultConfig.copy(
  // Memory capacities and dataflow selection.
  sp_capacity = CapacityInKilobytes(32),
  acc_capacity = CapacityInKilobytes(16),
  dataflow = Dataflow.WS,

  // Earlier alternative, kept for reference:
  //acc_scale_args=Some(defaultFPConfig.acc_scale_args.get.copy(num_scale_units=0, latency=1)),
  // The accumulator scale function returns its first operand unchanged; the
  // second (scale) operand is ignored, which matches the "((x))" C expression.
  acc_scale_args = Some(
    ScaleArguments((t: Float, _: Float) => t, 1, Float(8, 24), -1, identity = "1.0", c_str = "((x))")
  ),

  // Move-in scaling multiplies by the scale operand. The second positional
  // argument was changed from 4 to 3 by the author and is marked "(check)".
  mvin_scale_args = Some(
    ScaleArguments((t: Float, u: Float) => t * u, 3, Float(8, 24), -1, identity = "1.0", c_str = "((x) * (scale))")
  ), // 4-> 3 (check)
  // Earlier sign-select alternative, kept for reference:
  //mvin_scale_args = Some(ScaleArguments((t: Float, u: Float) => {Mux(u > 0.U.asTypeOf(Float(8, 24)), t, 0.U.asTypeOf(Float(8,24)) - t)}, 1, Float(8, 24), -1, identity = "1.0", c_str="((x) * (scale))")), // 2 -> 1 stage
  mvin_scale_acc_args = None,

  // Accumulator organisation.
  acc_singleported = false,
  acc_sub_banks = 1,
  acc_banks = 2,

  // Latency-related settings.
  mesh_output_delay = 2,
  tile_latency = 1,
  acc_latency = 3,

  // Feature switches.
  ex_read_from_acc = false,
  ex_write_to_spad = false,
  has_training_convs = false,
  hardcode_d_to_garbage_addr = true,
  acc_read_full_width = false,
  //has_loop_conv = false,

  max_in_flight_mem_reqs = 16,
  headerFileName = "gemmini_params_fp32.h",
  num_counter = 0,
  clock_gate = true // enable this
)

//FP16 Half Precision Configuration
val FP16DefaultConfig = defaultFPConfig.copy(inputType = Float(5, 11), spatialArrayOutputType = Float(5, 11), accType = Float(8, 24),
tile_latency = 2,
Expand Down Expand Up @@ -123,6 +149,17 @@ class GemminiFP32DefaultConfig extends Config((site, here, up) => {
)
})

/**
  * Config fragment that appends one Gemmini instance to the existing
  * `BuildRoCC` sequence.
  *
  * @param gemminiConfig array configuration passed to the `Gemmini`
  *                      constructor; defaults to `GemminiFPConfigs.slamFPConfig`
  */
class SLAMFPGemminiConfig[T <: Data : Arithmetic, U <: Data, V <: Data](
  gemminiConfig: GemminiArrayConfig[T,U,V] = GemminiFPConfigs.slamFPConfig
) extends Config((site, here, up) => {
  case BuildRoCC =>
    // Keep whatever builders are already registered and add ours at the end.
    up(BuildRoCC) ++ Seq((params: Parameters) => {
      // Gemmini's constructor is given its Parameters implicitly.
      implicit val q: Parameters = params
      LazyModule(new Gemmini(gemminiConfig))
    })
})

//===========FP16 Default Config=========
class GemminiFP16DefaultConfig extends Config((site, here, up) => {
Expand Down
38 changes: 27 additions & 11 deletions src/main/scala/gemmini/LoopMatmul.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val blocks = Mux(col_iterator + max_blocks <= max_col_iterator, max_blocks, max_col_iterator-col_iterator)
val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U)
val rows = block_size.U - Mux(row_iterator === max_row_iterator-1.U, row_pad, 0.U)
dontTouch(rows)
dontTouch(cols)
dontTouch(sp_addr)

val mvin_cmd = Wire(new RoCCCommand)
mvin_cmd := DontCare
Expand All @@ -82,7 +85,7 @@ class LoopMatmulLdA(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
mvin_cmd_rs2.local_addr := cast_to_sp_addr(mvin_cmd_rs2.local_addr, sp_addr)
mvin_cmd.rs2 := mvin_cmd_rs2.asUInt
when(req.is_resadd){
mvin_cmd_rs2.local_addr := cast_to_acc_addr(mvin_cmd_rs2.local_addr, sp_addr, accumulate = false.B, read_full = false.B)
mvin_cmd_rs2.local_addr := cast_to_acc_addr(mvin_cmd_rs2.local_addr, sp_addr, accumulate = true.B, read_full = false.B)
}

io.req.ready := state === idle
Expand Down Expand Up @@ -182,7 +185,10 @@ class LoopMatmulLdB(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val sp_addr = sp_addr_start + (row_iterator * max_col_iterator + col_iterator) * block_size.U
val blocks = Mux(col_iterator + max_blocks <= max_col_iterator, max_blocks, max_col_iterator-col_iterator)
val cols = (blocks * block_size.U) - Mux(col_iterator + blocks >= max_col_iterator, col_pad, 0.U)
val rows = block_size.U - Mux(max_row_iterator === max_row_iterator-1.U, row_pad, 0.U)
val rows = block_size.U - Mux(row_iterator === max_row_iterator-1.U, row_pad, 0.U)
dontTouch(rows)
dontTouch(cols)
dontTouch(sp_addr)

val mvin_cmd = Wire(new RoCCCommand)
mvin_cmd := DontCare
Expand Down Expand Up @@ -248,6 +254,7 @@ class LoopMatmulLdDReq(val block_size: Int, val coreMaxAddrBits: Int, val iterat
val low_d = Bool()
val addr_start = UInt(log2Up(max_acc_addr).W)
val loop_id = UInt(log2Up(concurrent_loops).W)
val is_resadd = Bool()
}

class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: Int, max_acc_addr: Int, input_w: Int,
Expand Down Expand Up @@ -281,12 +288,16 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In

val dram_offset = Mux(req.low_d, (i * req.dram_stride + j) * block_size.U * (input_w/8).U,
(i * req.dram_stride + j) * block_size.U * (acc_w/8).U)
val dram_addr = req.dram_addr + LoopMatmul.castDramOffset(dram_offset)
val dram_addr = Mux(req.is_resadd, 0.U, req.dram_addr + LoopMatmul.castDramOffset(dram_offset))
val sp_addr = acc_addr_start + (i * req.max_j + j) * block_size.U
val blocks = Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j)
val cols = (blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U)
val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U)

dontTouch(rows)
dontTouch(cols)
dontTouch(sp_addr)

val mvin_cmd = Wire(new RoCCCommand)
mvin_cmd := DontCare
mvin_cmd.inst.funct := LOAD3_CMD
Expand All @@ -303,12 +314,12 @@ class LoopMatmulLdD(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
io.idle := state === idle

// The order here is k, j, i
io.cmd.valid := state =/= idle && !io.rob_overloaded && req.dram_addr =/= 0.U
io.cmd.valid := state =/= idle && !io.rob_overloaded && !(req.dram_addr === 0.U && !req.is_resadd)
io.cmd.bits := mvin_cmd

io.loop_id := req.loop_id

when (req.dram_addr === 0.U) {
when (req.dram_addr === 0.U && !req.is_resadd) {
state := idle
}.elsewhen (io.cmd.fire) {
// The order here is k, j, i
Expand Down Expand Up @@ -554,6 +565,9 @@ class LoopMatmulStC(block_size: Int, coreMaxAddrBits: Int, iterator_bitwidth: In
val blocks = Mux(j + max_blocks <= req.max_j, max_blocks, req.max_j-j)
val cols = (blocks * block_size.U) - Mux(j + blocks >= req.max_j, req.pad_j, 0.U)
val rows = block_size.U - Mux(i === req.max_i-1.U, req.pad_i, 0.U)
dontTouch(rows)
dontTouch(cols)
dontTouch(sp_addr)

val mvout_cmd = Wire(new RoCCCommand)
mvout_cmd := DontCare
Expand Down Expand Up @@ -809,6 +823,10 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size

io.busy := cmd.valid || loop_configured

// Create reservation station utilization counters
val ld_utilization = RegInit(0.U(log2Up(max_lds+1).W))
val st_utilization = RegInit(0.U(log2Up(max_sts+1).W))
val ex_utilization = RegInit(0.U(log2Up(max_exs+1).W))
// Create ld arbiters
val ldab_arb = Module(new WeightedArbiter(new RoCCCommand(), maxWeightA=255, staticWeightAEnabled=true)) // TODO magic numbers
ldab_arb.io.inA <> ldA.io.cmd
Expand All @@ -818,6 +836,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
val forceB = !ab_loads_on_same_loop && ldB.io.loop_id === head_loop_id
ldab_arb.io.forceA := Mux(is_resadd, ab_loads_on_same_loop && !ldA.io.idle, forceA)
ldab_arb.io.forceB := Mux(is_resadd, forceB || ldA.io.idle, forceB)
//ldab_arb.io.forceB := Mux(is_resadd, (forceB || ldA.io.idle) && (ld_utilization === 0.U), forceB)
ldab_arb.io.weightA := 0.U
ldab_arb.io.inA_idle := ldA.io.idle
ldab_arb.io.inB_idle := ldB.io.idle
Expand All @@ -834,11 +853,6 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
arb.io.in(3) <> ldab_arb.io.out
val unrolled_cmd = arb.io.out

// Create reservation station utilization counters
val ld_utilization = RegInit(0.U(log2Up(max_lds+1).W))
val st_utilization = RegInit(0.U(log2Up(max_sts+1).W))
val ex_utilization = RegInit(0.U(log2Up(max_exs+1).W))

ld_utilization := ld_utilization +& (ldA.io.cmd.fire || ldB.io.cmd.fire || ldD.io.cmd.fire) -& io.ld_completed
st_utilization := st_utilization +& stC.io.cmd.fire -& io.st_completed
ex_utilization := ex_utilization +& ex.io.cmd.fire -& io.ex_completed
Expand All @@ -859,7 +873,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
io.out.bits.from_conv_fsm := Mux(loop_configured, false.B, cmd.bits.from_conv_fsm)
io.out.valid := Mux(loop_configured, unrolled_cmd.valid, cmd.valid && !is_loop_config_cmd && !is_loop_run_cmd)

cmd.ready := Mux(is_loop_cmd, !loop_being_configured.configured, !loop_configured && io.out.ready)
cmd.ready := Mux(is_loop_cmd, !loop_being_configured.configured && !(is_resadd && ld_utilization > 0.U), !loop_configured && io.out.ready)
arb.io.out.ready := io.out.ready

// Wire up overloaded signals
Expand Down Expand Up @@ -1035,6 +1049,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
ldD.io.req.bits.low_d := loop_requesting_ldD.low_d
ldD.io.req.bits.addr_start := ld_d_addr_start
ldD.io.req.bits.loop_id := loop_requesting_ldD_id
ldD.io.req.bits.is_resadd := is_resadd

ldD.io.req.valid := !loop_requesting_ldD.ldd_started && loop_requesting_ldD.configured

Expand Down Expand Up @@ -1077,6 +1092,7 @@ class LoopMatmul(block_size: Int, coreMaxAddrBits: Int, reservation_station_size
when(is_resadd){
ldA.io.req.bits.addr_start := loop_requesting_ldA.resadd_addr_start
ldB.io.req.bits.addr_end := loop_requesting_ldB.resadd_addr_start
ldD.io.req.bits.addr_start := loop_requesting_ldD.resadd_addr_start
stC.io.req.bits.addr_start := loop_requesting_st.resadd_addr_start
stC.io.req.valid := !loop_requesting_st.st_started && loop_requesting_st.configured
}
Expand Down

0 comments on commit 25710ae

Please sign in to comment.