From ef19f88dc8224dfc2ca0c12f78c828f888d01a91 Mon Sep 17 00:00:00 2001 From: Minh Nguyen Date: Mon, 6 May 2024 15:36:06 -0700 Subject: [PATCH] modifying ex controller for packed b vector --- .../scala/gemmini/ExecuteController.scala | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/main/scala/gemmini/ExecuteController.scala b/src/main/scala/gemmini/ExecuteController.scala index e87d036e..cdb10288 100644 --- a/src/main/scala/gemmini/ExecuteController.scala +++ b/src/main/scala/gemmini/ExecuteController.scala @@ -120,6 +120,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val bd_transpose = Reg(Bool()) val config_initialized = RegInit(false.B) + val is_gemv = WireInit(true.B) + val a_should_be_fed_into_transposer = Mux(current_dataflow === Dataflow.OS.id.U, !a_transpose, a_transpose) val a_address_place = Mux(preload_cmd_place === 0.U, 1.U, Mux(a_should_be_fed_into_transposer, 2.U, 0.U)) @@ -250,10 +252,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In val c_addr_stride = Reg(UInt(16.W)) // TODO magic numbers val a_address = (0 until tileColumns).map(i => a_address_rs1(i) + a_addr_offset(i)) - val b_address = b_address_rs2 + b_fire_counter - dontTouch(b_address) - dontTouch(b_address_rs2) - val d_address = d_address_rs1 + (block_size.U - 1.U - d_fire_counter) + val b_address = Mux(is_gemv, b_address_rs2, b_address_rs2 + b_fire_counter) + val d_address = Mux(is_gemv, d_address_rs1, d_address_rs1 + (block_size.U - 1.U - d_fire_counter)) dontTouch(d_address) val dataAbank = a_address.map(address => address.sp_bank()) @@ -456,8 +456,6 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In for (i <- 0 until sp_banks) { // val matching_a = dataAbank.indexOf(i.U) val matching_a = if (i < 4) i else -1; // TODO temp fix bc indexOf() doesn't work for some reason - val matching_a_wire = WireInit(matching_a.S(4.W)); - dontTouch(matching_a_wire) val read_a = if (matching_a == -1) false.B else a_valid(matching_a) && !a_read_from_acc && start_inputting_a && !multiply_garbage && a_row_is_not_all_zeros(matching_a) && !(im2col_wire&&im2col_en) val read_b = b_valid && !b_read_from_acc && dataBbank === i.U && start_inputting_b && !accumulate_zeros && b_row_is_not_all_zeros //&& !im2col_wire dontTouch(b_valid) @@ -604,6 +602,8 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } } + // is_gemv := config_ex_rs1.is_gemv.asBool + a_addr_stride := config_ex_rs1.a_stride // TODO this needs to be kept in sync with ROB.scala c_addr_stride := config_ex_rs2.c_stride // TODO this needs to be kept in sync with ROB.scala config_initialized := true.B @@ -631,7 +631,7 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In //start_inputting_a := current_dataflow === Dataflow.OS.id.U //start_inputting_d := true.B - + start_inputting_a := a_should_be_fed_into_transposer start_inputting_b := b_should_be_fed_into_transposer start_inputting_d := true.B @@ -925,15 +925,15 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } } - // TODO integrate this fully - val gemv_mode = RegInit(true.B) dontTouch(dataB) dontTouch(cntl_valid) dontTouch(mesh.io.a.valid) dontTouch(dataD) + dontTouch(is_gemv) - when (gemv_mode) { + when (is_gemv) { when ((current_dataflow === Dataflow.WS.id.U).asBool) { + // transpose A for (tc <- 0 until tileColumns) { for (mr <- 0 until meshRows) { for (tr <- 0 until tileRows) { @@ -941,21 +941,31 @@ class ExecuteController[T <: Data, U <: Data, V <: Data](xLen: Int, tagWidth: In } } } + // pass in duplicated elements of weights vector in reverse order for (tc <- 0 until tileColumns) { for (mc <- 0 until meshColumns) { - mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(0) + mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(meshRows.U - d_fire_counter) + } + } + // duplicate one element of the bias vector to the mesh + for (tc <- 0 until tileColumns) { + for (mc <- 0 until meshColumns) { + mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(b_fire_counter-1.U) } } - mesh.io.b.bits := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType))) }.otherwise { // TODO this only works when casted this way mesh.io.a.bits := dataA.asTypeOf(Vec(meshRows, Vec(tileColumns, Vec(tileRows, inputType)))) for (tc <- 0 until tileColumns) { for (mc <- 0 until meshColumns) { - mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(0) + mesh.io.b.bits(mc)(tc) := dataB.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(b_fire_counter-1.U) + } + } + for (tc <- 0 until tileColumns) { + for (mc <- 0 until meshColumns) { + mesh.io.d.bits(mc)(tc) := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType)))(0)(d_fire_counter-1.U) } } - mesh.io.d.bits := dataD.asTypeOf(Vec(meshColumns, Vec(tileColumns, inputType))) } }.otherwise { for (tc <- 0 until tileColumns) {