-
Notifications
You must be signed in to change notification settings - Fork 30
/
MemCtrl.scala
146 lines (136 loc) · 6.91 KB
/
MemCtrl.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package dla.diplomatic
import chisel3._
import chisel3.util._
import dla.pe.PESizeConfig
import freechips.rocketchip.util._
/** Elaboration parameters of the Eyeriss memory controller.
  *
  * @param addressBits    width of every start / request address
  * @param inActSizeBits  width of the inAct request-size input
  * @param weightSizeBits width of the weight request-size input
  * @param pSumSizeBits   width of the pSum request-size input
  * @param inActIds       number of inAct source ids
  * @param weightIds      number of weight source ids
  * @param pSumIds        number of pSum source ids
  */
case class EyerissMemCtrlParameters(
  addressBits: Int,
  inActSizeBits: Int,
  weightSizeBits: Int,
  pSumSizeBits: Int,
  inActIds: Int,
  weightIds: Int,
  pSumIds: Int
)
/** Per-stream port of the memory controller.
  *
  * @param nIds        number of source ids of the stream; id fields are `log2Ceil(nIds)` bits wide
  * @param addressBits width of the address fields
  * @param sizeBits    width of the request-size field
  */
class MemCommonIO(val nIds: Int, val addressBits: Int, val sizeBits: Int) extends Bundle {
  /** request address of the stream (driven by the controller) */
  val address: UInt = Output(UInt(addressBits.W))
  /** source id handed out by the controller (valid/bits out, ready in) */
  val sourceAlloc: DecoupledIO[UInt] = new DecoupledIO(UInt(log2Ceil(nIds).W))
  /** source id returned to the controller (valid/bits in, ready out) */
  val sourceFree: DecoupledIO[UInt] = Flipped(new DecoupledIO(UInt(log2Ceil(nIds).W)))
  /** base address of the stream */
  val startAdr: UInt = Input(UInt(addressBits.W))
  /** request size of the stream */
  val reqSize: UInt = Input(UInt(sizeBits.W))
}
/** Top-level port of the memory controller: one [[MemCommonIO]] per data stream.
  * All three share `p.addressBits`; id counts and size widths come from the per-stream
  * fields of `p`.
  */
class EyerissMemCtrlIO()(implicit val p: EyerissMemCtrlParameters) extends Bundle {
  val inActIO: MemCommonIO = new MemCommonIO(p.inActIds, p.addressBits, p.inActSizeBits)
  val weightIO: MemCommonIO = new MemCommonIO(p.weightIds, p.addressBits, p.weightSizeBits)
  val pSumIO: MemCommonIO = new MemCommonIO(p.pSumIds, p.addressBits, p.pSumSizeBits)
}
/** Generates the request address and source id for the three memory streams
  * (inAct, weight, pSum), intended for TileLink get/put requests.
  *
  * Each stream owns an [[EyerissIDMapGenerator]] whose alloc/free handshakes are exposed
  * directly on that stream's [[MemCommonIO]]. A stream's address is its registered start
  * address plus a running offset:
  *  - the offset grows by the registered request size on every accepted id allocation;
  *  - it is cleared when the stream's id pool reports `finish`.
  *
  * inAct walks its id pool twice per pass, so its offset is only cleared on every second
  * `finish`.
  *
  * NOTE(review): the request size is added to the offset as-is, i.e. it is used as an
  * address increment rather than a log2 size — confirm against the decoder that drives it.
  */
class EyerissMemCtrlModule()(implicit val p: EyerissMemCtrlParameters) extends Module
  with PESizeConfig {
  val io: EyerissMemCtrlIO = IO(new EyerissMemCtrlIO()(p))

  // ---- one source-id pool per stream ----
  protected val inActIdMap: EyerissIDMapGenerator = Module(new EyerissIDMapGenerator(p.inActIds))
  inActIdMap.suggestName("inActIdMap")
  protected val weightIdMap: EyerissIDMapGenerator = Module(new EyerissIDMapGenerator(p.weightIds))
  weightIdMap.suggestName("weightIdMap")
  protected val pSumIdMap: EyerissIDMapGenerator = Module(new EyerissIDMapGenerator(p.pSumIds))
  pSumIdMap.suggestName("pSumIdMap")
  protected val inActIdMapIO: EyerissIDMapGenIO = inActIdMap.io
  protected val weightIdMapIO: EyerissIDMapGenIO = weightIdMap.io
  protected val pSumIdMapIO: EyerissIDMapGenIO = pSumIdMap.io

  // ---- start addresses, re-latched from the inputs every cycle ----
  protected val inActStarAdrReg: UInt = RegInit(0.U(p.addressBits.W))
  protected val weightStarAdrReg: UInt = RegInit(0.U(p.addressBits.W))
  protected val pSumStarAdrReg: UInt = RegInit(0.U(p.addressBits.W))

  // ---- addresses driven onto each stream's output ----
  protected val inActReqAdrWire: UInt = Wire(UInt(p.addressBits.W))
  protected val weightReqAdrWire: UInt = Wire(UInt(p.addressBits.W))
  protected val pSumReqAdrWire: UInt = Wire(UInt(p.addressBits.W))

  // ---- request sizes, re-latched from the inputs every cycle ----
  protected val inActReqSizeReg: UInt = RegInit(0.U(p.inActSizeBits.W))
  inActReqSizeReg.suggestName("inActReqSizeReg")
  protected val weightReqSizeReg: UInt = RegInit(0.U(p.weightSizeBits.W))
  protected val pSumReqSizeReg: UInt = RegInit(0.U(p.pSumSizeBits.W))

  // ---- running address offsets ----
  protected val inActReqAdrAcc: UInt = RegInit(0.U(p.addressBits.W))
  inActReqAdrAcc.suggestName("inActReqAdrAcc")
  protected val weightReqAdrAcc: UInt = RegInit(0.U(p.addressBits.W))
  weightReqAdrAcc.suggestName("weightReqAdrAcc")
  protected val pSumReqAdrAcc: UInt = RegInit(0.U(p.addressBits.W))

  /** Toggles on every inAct `finish`; true once the first of the two inAct passes is done. */
  protected val inActReqFinOnce: Bool = RegInit(false.B)
  inActReqFinOnce.suggestName("inActReqFinOnce")

  // ==================== inAct ====================
  io.inActIO.sourceAlloc <> inActIdMapIO.alloc
  io.inActIO.sourceFree <> inActIdMapIO.free
  inActStarAdrReg := io.inActIO.startAdr
  inActReqSizeReg := io.inActIO.reqSize
  when(inActIdMapIO.finish) {
    inActReqFinOnce := !inActReqFinOnce
  }
  // a finish seen while inActReqFinOnce is already set is the second one, i.e. the real end
  when(inActReqFinOnce && inActIdMapIO.finish) {
    inActReqAdrAcc := 0.U
  }.elsewhen(inActIdMapIO.alloc.fire()) {
    inActReqAdrAcc := inActReqAdrAcc + inActReqSizeReg
  }
  inActReqAdrWire := inActStarAdrReg + inActReqAdrAcc
  io.inActIO.address := inActReqAdrWire

  // ==================== weight ====================
  io.weightIO.sourceAlloc <> weightIdMapIO.alloc
  io.weightIO.sourceFree <> weightIdMapIO.free
  weightStarAdrReg := io.weightIO.startAdr
  weightReqSizeReg := io.weightIO.reqSize
  when(weightIdMapIO.finish) {
    weightReqAdrAcc := 0.U
  }.elsewhen(weightIdMapIO.alloc.fire()) {
    weightReqAdrAcc := weightReqAdrAcc + weightReqSizeReg
  }
  weightReqAdrWire := weightStarAdrReg + weightReqAdrAcc
  io.weightIO.address := weightReqAdrWire

  // ==================== pSum ====================
  io.pSumIO.sourceAlloc <> pSumIdMapIO.alloc
  io.pSumIO.sourceFree <> pSumIdMapIO.free
  pSumStarAdrReg := io.pSumIO.startAdr
  pSumReqSizeReg := io.pSumIO.reqSize
  when(pSumIdMapIO.finish) {
    pSumReqAdrAcc := 0.U
  }.elsewhen(pSumIdMapIO.alloc.fire()) {
    pSumReqAdrAcc := pSumReqAdrAcc + pSumReqSizeReg
  }
  pSumReqAdrWire := pSumStarAdrReg + pSumReqAdrAcc
  io.pSumIO.address := pSumReqAdrWire
}
/** Port of [[EyerissIDMapGenerator]].
  *
  * @param sourceWidth width in bits of a source id
  */
class EyerissIDMapGenIO(val sourceWidth: Int) extends Bundle {
  /** id returned to the pool (valid/bits in, ready out); marks that id as answered */
  val free: DecoupledIO[UInt] = Flipped(new DecoupledIO(UInt(sourceWidth.W)))
  /** next id not yet handed out in the current round (valid/bits out, ready in) */
  val alloc: DecoupledIO[UInt] = new DecoupledIO(UInt(sourceWidth.W))
  /** asserted when every id of the pool has been freed */
  val finish: Bool = Output(Bool())
}
/** Source-id pool with `numIds` entries.
  *
  * Ids are offered lowest-index first on `io.alloc` and given back on `io.free`, which
  * is always ready. An id is not offered again until every id in the pool has been freed.
  * At that point `io.finish` is raised and both bitmaps return to their reset values:
  * all ids available, none answered.
  *
  * @param numIds number of ids managed; must be positive
  */
class EyerissIDMapGenerator(val numIds: Int) extends Module {
  require(numIds > 0)
  private val w = log2Up(numIds)
  val io: EyerissIDMapGenIO = IO(new EyerissIDMapGenIO(w))

  /** `numIds` ones: reset/refill value of [[reqBitmap]] */
  private val allIds: UInt = ((BigInt(1) << numIds) - 1).U(numIds.W)

  /** bit i set => id i has not been handed out yet in the current round */
  protected val reqBitmap: UInt = RegInit(allIds)
  /** bit i set => id i has been freed (its response arrived) in the current round */
  protected val respBitmap: UInt = RegInit(0.U(numIds.W))

  /** One-hot of the lowest set bit of [[reqBitmap]]: `leftOR(x) << 1` marks every bit
    * that has a set bit below it, so masking those out leaves only the lowest one.
    */
  protected val select: UInt = reqBitmap & (~(leftOR(reqBitmap) << 1)).asUInt

  // one-hot masks of the id allocated / freed this cycle; zero when the handshake is idle
  protected val clr: UInt = WireDefault(0.U(numIds.W))
  protected val set: UInt = WireDefault(0.U(numIds.W))
  when(io.alloc.fire()) { clr := UIntToOH(io.alloc.bits) }
  when(io.free.fire()) { set := UIntToOH(io.free.bits) } // the freed sourceId

  /** every id has been answered: end of the current round */
  protected val finishWire: Bool = respBitmap.andR()

  when(finishWire) {
    // start a new round: everything available again, nothing answered yet
    respBitmap := 0.U
    reqBitmap := allIds
  }.otherwise {
    respBitmap := respBitmap | set
    reqBitmap := reqBitmap & (~clr).asUInt
  }

  io.free.ready := true.B
  io.alloc.valid := reqBitmap =/= 0.U // some id is still waiting to be handed out
  io.alloc.bits := OHToUInt(select)
  io.finish := finishWire
  // NOTE(review): freed ids are not checked against the set of allocated ids; the
  // original double-free assertion was left disabled:
  //assert(!io.free.valid || !(reqBitmap & (~clr).asUInt) (io.free.bits)) // No double freeing
}