Skip to content

Commit

Permalink
modifying ex controller for packed b vector
Browse files Browse the repository at this point in the history
  • Loading branch information
Minh Nguyen authored and Minh Nguyen committed May 6, 2024
1 parent 94895ad commit ef19f88
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions src/main/scala/gemmini/ExecuteController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val bd_transpose = Reg(Bool())
val config_initialized = RegInit(false.B)

val is_gemv = WireInit(true.B)

val a_should_be_fed_into_transposer = Mux(current_dataflow === Dataflow.OS.id.U, !a_transpose, a_transpose)
val a_address_place = Mux(preload_cmd_place === 0.U, 1.U, Mux(a_should_be_fed_into_transposer, 2.U, 0.U))

Expand Down Expand Up @@ -250,10 +252,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
val c_addr_stride = Reg(UInt(16.W)) // TODO magic numbers

val a_address = (0 until tileColumns).map(i => a_address_rs1(i) + a_addr_offset(i))
val b_address = b_address_rs2 + b_fire_counter
dontTouch(b_address)
dontTouch(b_address_rs2)
val d_address = d_address_rs1 + (block_size.U - 1.U - d_fire_counter)
val b_address = Mux(is_gemv, b_address_rs2, b_address_rs2 + b_fire_counter)
val d_address = Mux(is_gemv, d_address_rs1, d_address_rs1 + (block_size.U - 1.U - d_fire_counter))
dontTouch(d_address)

val dataAbank = a_address.map(address => address.sp_bank())
Expand Down Expand Up @@ -456,8 +456,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
for (i <- 0 until sp_banks) {
// val matching_a = dataAbank.indexOf(i.U)
val matching_a = if (i < 4) i else -1; // TODO temp fix bc indexOf() doesn't work for some reason
val matching_a_wire = WireInit(matching_a.S(4.W));
dontTouch(matching_a_wire)
val read_a = if (matching_a == -1) false.B else a_valid(matching_a) && !a_read_from_acc && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros(matching_a) && !(im2col_wire&&im2col_en)
val read_b = b_valid && !b_read_from_acc && dataBbank === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire
dontTouch(b_valid)
Expand Down Expand Up @@ -604,6 +602,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
}
}

// is_gemv := config_ex_rs1.is_gemv.asBool

a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala
c_addr_stride := config_ex_rs2.c_stride // TODO this needs to be kept in sync with ROB.scala
config_initialized := true.B
Expand Down Expand Up @@ -631,7 +631,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In

//start_inputting_a := current_dataflow === Dataflow.OS.id.U
//start_inputting_d := true.B

start_inputting_a := a_should_be_fed_into_transposer
start_inputting_b := b_should_be_fed_into_transposer
start_inputting_d := true.B
Expand Down Expand Up @@ -925,37 +925,47 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In
}
}

// TODO integrate this fully
val gemv_mode = RegInit(true.B)
dontTouch(dataB)
dontTouch(cntl_valid)
dontTouch(mesh.io.a.valid)
dontTouch(dataD)
dontTouch(is_gemv)

when (gemv_mode) {
when (is_gemv) {
when ((current_dataflow === Dataflow.WS.id.U).asBool) {
// transpose A
for (tc <- 0 until tileColumns) {
for (mr <- 0 until meshRows) {
for (tr <- 0 until tileRows) {
mesh.io.a.bits(mr)(tc)(tr) := dataA.asTypeOf(Vec(tileColumns, Vec(meshRows, Vec(tileRows, inputType))))(tc)(mr)(tr)
}
}
}
// pass in duplicated elements of weights vector in reverse order
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(0)
mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(meshRows.U - d_fire_counter)
}
}
// duplicate one element of the bias vector to the mesh
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(b_fire_counter-1.U)
}
}
mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
}.otherwise {
// TODO this only works when casted this way
mesh.io.a.bits := dataA.asTypeOf(Vec(meshRows, Vec(tileColumns, Vec(tileRows, inputType))))
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(0)
mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(b_fire_counter-1.U)
}
}
for (tc <- 0 until tileColumns) {
for (mc <- 0 until meshColumns) {
mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(d_fire_counter-1.U)
}
}
mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))
}
}.otherwise {
for (tc <- 0 until tileColumns) {
Expand Down

0 comments on commit ef19f88

Please sign in to comment.