Skip to content

Commit

Permalink
Merge pull request #171 from ucb-bar/dev
Browse files Browse the repository at this point in the history
Merge Dev Into Master
  • Loading branch information
hngenc authored Dec 7, 2021
2 parents 6c620f5 + f2f74f6 commit 44b28e4
Show file tree
Hide file tree
Showing 27 changed files with 935 additions and 302 deletions.
2 changes: 1 addition & 1 deletion SPIKE.hash
Original file line number Diff line number Diff line change
@@ -1 +1 @@
34741e07bc6b56f1762ce579537948d58e28cd5a
02e2d983cc8e2c385ebe920302c427b9167bd76e
253 changes: 176 additions & 77 deletions src/main/scala/gemmini/AccumulatorMem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,59 @@ class AccumulatorWriteReq[T <: Data: Arithmetic](n: Int, t: Vec[Vec[T]]) extends
override def cloneType: this.type = new AccumulatorWriteReq(n, t).asInstanceOf[this.type]
}

class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U) extends Bundle {

class AccumulatorMemIO [T <: Data: Arithmetic, U <: Data](n: Int, t: Vec[Vec[T]], scale_t: U,
acc_sub_banks: Int, use_shared_ext_mem: Boolean
) extends Bundle {
val read = Flipped(new AccumulatorReadIO(n, log2Ceil(t.head.head.getWidth), t, scale_t))
// val write = Flipped(new AccumulatorWriteIO(n, t))
val write = Flipped(Decoupled(new AccumulatorWriteReq(n, t)))

override def cloneType: this.type = new AccumulatorMemIO(n, t, scale_t).asInstanceOf[this.type]
val ext_mem = if (use_shared_ext_mem) Some(Vec(acc_sub_banks, new ExtMemIO)) else None

val adder = new Bundle {
val valid = Output(Bool())
val op1 = Output(t.cloneType)
val op2 = Output(t.cloneType)
val sum = Input(t.cloneType)
}

override def cloneType: this.type = new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem).asInstanceOf[this.type]
}

// A single-element pipelined adder for the accumulator: computes op1 + op2
// using the Arithmetic typeclass's `+` and delays the result by `latency`
// cycles through a shift register.
// NOTE(review): the class declares both a context bound (`T : Arithmetic`) and
// an explicit `implicit ev: Arithmetic[T]`; `import ev._` is what brings the
// `+` operation into scope here — confirm the duplicate implicit is intended.
class AccPipe[T <: Data : Arithmetic](latency: Int, t: T)(implicit ev: Arithmetic[T]) extends Module {
val io = IO(new Bundle {
val op1 = Input(t.cloneType) // first addend
val op2 = Input(t.cloneType) // second addend
val sum = Output(t.cloneType) // op1 + op2, available `latency` cycles after the operands
})
import ev._
// Pipeline the combinational sum by `latency` register stages.
io.sum := ShiftRegister(io.op1 + io.op2, latency)
}

// An adder array shared across `banks` accumulator sub-banks. A one-hot
// select (io.in_sel) picks one bank's pair of matrix operands via Mux1H, and
// each matrix element is summed by its own AccPipe instance — so the adders
// are instantiated once and time-shared, instead of duplicated per bank.
class AccPipeShared[T <: Data : Arithmetic](latency: Int, t: Vec[Vec[T]], banks: Int) extends Module {
val io = IO(new Bundle {
// Bank select — presumably at most one bit set per cycle, as Mux1H's
// output is undefined for non-one-hot selects; TODO confirm at the caller.
val in_sel = Input(Vec(banks, Bool()))
val ina = Input(Vec(banks, t.cloneType)) // per-bank operand A (matrix)
val inb = Input(Vec(banks, t.cloneType)) // per-bank operand B (matrix)
val out = Output(t.cloneType) // elementwise sum of the selected bank's operands, delayed by `latency`
})
// Select the active bank's operand matrices.
val ina = Mux1H(io.in_sel, io.ina)
val inb = Mux1H(io.in_sel, io.inb)
// One AccPipe per matrix element; results are reassembled into Vec[Vec[T]].
io.out := VecInit((ina zip inb).map { case (rv, wv) =>
VecInit((rv zip wv).map { case (re, we) =>
val m = Module(new AccPipe(latency, t.head.head.cloneType))
m.io.op1 := re
m.io.op2 := we
m.io.sum
})
})
}

class AccumulatorMem[T <: Data, U <: Data](
n: Int, t: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U,
acc_singleported: Boolean, acc_sub_banks: Int
n: Int, t: Vec[Vec[T]], scale_func: (T, U) => T, scale_t: U,
acc_singleported: Boolean, acc_sub_banks: Int,
use_shared_ext_mem: Boolean,
acc_latency: Int, acc_type: T
)
(implicit ev: Arithmetic[T]) extends Module {
// TODO Do writes in this module work with matrices of size 2? If we try to read from an address right after writing
Expand All @@ -69,54 +111,91 @@ class AccumulatorMem[T <: Data, U <: Data](
import ev._

// TODO unify this with TwoPortSyncMemIO
val io = IO(new AccumulatorMemIO(n, t, scale_t))


// For any write operation, we spend 2 cycles reading the existing address out, buffering it in a register, and then
// accumulating on top of it (if necessary)
val wdata_buf = ShiftRegister(io.write.bits.data, 2)
val waddr_buf = ShiftRegister(io.write.bits.addr, 2)
val acc_buf = ShiftRegister(io.write.bits.acc, 2)
val mask_buf = ShiftRegister(io.write.bits.mask, 2)
val w_buf_valid = ShiftRegister(io.write.fire(), 2)
val acc_rdata = Wire(t)
acc_rdata := DontCare
val read_rdata = Wire(t)
read_rdata := DontCare
val io = IO(new AccumulatorMemIO(n, t, scale_t, acc_sub_banks, use_shared_ext_mem))

require (acc_latency >= 2)

val pipelined_writes = Reg(Vec(acc_latency, Valid(new AccumulatorWriteReq(n, t))))
val oldest_pipelined_write = pipelined_writes(acc_latency-1)
pipelined_writes(0).valid := io.write.fire()
pipelined_writes(0).bits := io.write.bits
for (i <- 1 until acc_latency) {
pipelined_writes(i) := pipelined_writes(i-1)
}

val rdata_for_adder = Wire(t)
rdata_for_adder := DontCare
val rdata_for_read_resp = Wire(t)
rdata_for_read_resp := DontCare

val adder_sum = io.adder.sum
io.adder.valid := pipelined_writes(0).valid && pipelined_writes(0).bits.acc
io.adder.op1 := rdata_for_adder
io.adder.op2 := pipelined_writes(0).bits.data

val block_read_req = WireInit(false.B)
val w_sum = VecInit((RegNext(acc_rdata) zip wdata_buf).map { case (rv, wv) =>
VecInit((rv zip wv).map(t => t._1 + t._2))
})
val block_write_req = WireInit(false.B)

val mask_len = t.getWidth / 8
val mask_elem = UInt((t.getWidth / mask_len).W)

if (!acc_singleported) {
val mem = TwoPortSyncMem(n, t, t.getWidth / 8) // TODO We assume byte-alignment here. Use aligned_to instead
mem.io.waddr := waddr_buf
mem.io.wen := w_buf_valid
mem.io.wdata := Mux(acc_buf, w_sum, wdata_buf)
mem.io.mask := mask_buf
acc_rdata := mem.io.rdata
read_rdata := mem.io.rdata
require(!use_shared_ext_mem)
val mem = TwoPortSyncMem(n, t, mask_len) // TODO We assume byte-alignment here. Use aligned_to instead
mem.io.waddr := oldest_pipelined_write.bits.addr
mem.io.wen := oldest_pipelined_write.valid
mem.io.wdata := Mux(oldest_pipelined_write.bits.acc, adder_sum, oldest_pipelined_write.bits.data)
mem.io.mask := oldest_pipelined_write.bits.mask
rdata_for_adder := mem.io.rdata
rdata_for_read_resp := mem.io.rdata
mem.io.raddr := Mux(io.write.fire() && io.write.bits.acc, io.write.bits.addr, io.read.req.bits.addr)
mem.io.ren := io.read.req.fire() || (io.write.fire() && io.write.bits.acc)
} else {
val mask_len = t.getWidth / 8
val mask_elem = UInt((t.getWidth / mask_len).W)
val reads = Wire(Vec(2, Decoupled(UInt())))
reads(0).valid := io.write.valid && io.write.bits.acc
reads(0).bits := io.write.bits.addr
reads(0).ready := true.B
reads(1).valid := io.read.req.valid
reads(1).bits := io.read.req.bits.addr
reads(1).ready := true.B
block_read_req := !reads(1).ready
val rmw_req = Wire(Decoupled(UInt()))
rmw_req.valid := io.write.valid && io.write.bits.acc
rmw_req.bits := io.write.bits.addr
rmw_req.ready := true.B

block_write_req := !rmw_req.ready

val only_read_req = Wire(Decoupled(UInt()))
only_read_req.valid := io.read.req.valid
only_read_req.bits := io.read.req.bits.addr
only_read_req.ready := true.B

block_read_req := !only_read_req.ready

for (i <- 0 until acc_sub_banks) {
def isThisBank(addr: UInt) = addr(log2Ceil(acc_sub_banks)-1,0) === i.U
def getBankIdx(addr: UInt): UInt = (addr >> log2Ceil(acc_sub_banks)).asUInt()
val mem = SyncReadMem(n / acc_sub_banks, Vec(mask_len, mask_elem))
def getBankIdx(addr: UInt) = addr >> log2Ceil(acc_sub_banks)
val (read, write) = if (use_shared_ext_mem) {
def read(addr: UInt, ren: Bool): Data = {
io.ext_mem.get(i).read_en := ren
io.ext_mem.get(i).read_addr := addr
io.ext_mem.get(i).read_data
}
io.ext_mem.get(i).write_en := false.B
io.ext_mem.get(i).write_addr := DontCare
io.ext_mem.get(i).write_data := DontCare
io.ext_mem.get(i).write_mask := DontCare
def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = {
io.ext_mem.get(i).write_en := true.B
io.ext_mem.get(i).write_addr := addr
io.ext_mem.get(i).write_data := wdata.asUInt
io.ext_mem.get(i).write_mask := wmask.asUInt
}
(read _, write _)
} else {
val mem = SyncReadMem(n / acc_sub_banks, Vec(mask_len, mask_elem))
def read(addr: UInt, ren: Bool): Data = mem.read(addr, ren)
def write(addr: UInt, wdata: Vec[UInt], wmask: Vec[Bool]) = mem.write(addr, wdata, wmask)
(read _, write _)
}

val ren = WireInit(false.B)
val raddr = WireInit(getBankIdx(reads(0).bits))
val raddr = WireInit(getBankIdx(rmw_req.bits))
val nEntries = 3

// Writes coming 2 cycles after read leads to bad bank behavior
// Add another buffer here
class W_Q_Entry[T <: Data](mask_len: Int, mask_elem: T) extends Bundle {
Expand All @@ -126,25 +205,32 @@ class AccumulatorMem[T <: Data, U <: Data](
val addr = UInt(log2Ceil(n/acc_sub_banks).W)
override def cloneType: this.type = new W_Q_Entry(mask_len, mask_elem).asInstanceOf[this.type]
}

val w_q = Reg(Vec(nEntries, new W_Q_Entry(mask_len, mask_elem)))
for (e <- w_q) {
when (e.valid) {
assert(!(
io.write.valid && io.write.bits.acc &&
io.write.fire() && io.write.bits.acc &&
isThisBank(io.write.bits.addr) && getBankIdx(io.write.bits.addr) === e.addr &&
((io.write.bits.mask.asUInt & e.mask.asUInt) =/= 0.U)
))
), "you cannot accumulate to an AccumulatorMem address until previous writes to that address have completed")

when (io.write.bits.acc && isThisBank(io.write.bits.addr) && getBankIdx(io.write.bits.addr) === e.addr) {
rmw_req.ready := false.B
}

when (io.read.req.valid && isThisBank(io.read.req.bits.addr) && getBankIdx(io.read.req.bits.addr) === e.addr) {
reads(1).ready := false.B
when (isThisBank(io.read.req.bits.addr) && getBankIdx(io.read.req.bits.addr) === e.addr) {
only_read_req.ready := false.B
}
}
}

val w_q_head = RegInit(1.U(nEntries.W))
val w_q_tail = RegInit(1.U(nEntries.W))
when (reset.asBool) {
w_q.foreach(_.valid := false.B)
}

val w_q_full = (w_q_tail.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)
val w_q_empty = !(w_q_head.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)

val wen = WireInit(false.B)
val wdata = Mux1H(w_q_head.asBools, w_q.map(_.data))
val wmask = Mux1H(w_q_head.asBools, w_q.map(_.mask))
Expand All @@ -158,49 +244,61 @@ class AccumulatorMem[T <: Data, U <: Data](
}
}

when (w_buf_valid && isThisBank(waddr_buf)) {
assert(!((w_q_tail.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)))
val w_q_push = oldest_pipelined_write.valid && isThisBank(oldest_pipelined_write.bits.addr)

when (w_q_push) {
assert(!w_q_full || wen, "we ran out of acc-sub-bank write q entries")

w_q_tail := (w_q_tail << 1).asUInt() | w_q_tail(nEntries-1)
for (i <- 0 until nEntries) {
when (w_q_tail(i)) {
w_q(i).valid := true.B
w_q(i).data := Mux(acc_buf, w_sum, wdata_buf).asTypeOf(Vec(mask_len, mask_elem))
w_q(i).mask := mask_buf
w_q(i).addr := getBankIdx(waddr_buf)
w_q(i).data := Mux(oldest_pipelined_write.bits.acc, adder_sum, oldest_pipelined_write.bits.data).asTypeOf(Vec(mask_len, mask_elem))
w_q(i).mask := oldest_pipelined_write.bits.mask
w_q(i).addr := getBankIdx(oldest_pipelined_write.bits.addr)
}
}

}
val bank_rdata = mem.read(raddr, ren && !wen).asTypeOf(t)
when (RegNext(ren && reads(0).valid && isThisBank(reads(0).bits))) {
acc_rdata := bank_rdata

val bank_rdata = read(raddr, ren && !wen).asTypeOf(t)
when (RegNext(ren && rmw_req.valid && isThisBank(rmw_req.bits))) {
rdata_for_adder := bank_rdata
} .elsewhen (RegNext(ren)) {
read_rdata := bank_rdata
rdata_for_read_resp := bank_rdata
}

when (wen) {
mem.write(waddr, wdata, wmask)
write(waddr, wdata, wmask)
}

// Three requestors, 1 slot
// Priority is incoming reads for RMW > writes from RMW > incoming reads
when (reads(0).valid && isThisBank(reads(0).bits)) {
// Priority is (in descending order):
// 1. incoming reads for RMW
// 2. writes from RMW
// 3. incoming reads
when (rmw_req.fire() && isThisBank(rmw_req.bits)) {
ren := true.B
when (isThisBank(reads(1).bits)) {
reads(1).ready := false.B
when (isThisBank(only_read_req.bits)) {
only_read_req.ready := false.B
}
} .elsewhen ((w_q_head.asBools zip w_q.map(_.valid)).map({ case (h,v) => h && v }).reduce(_||_)) {
} .elsewhen (!w_q_empty) {
wen := true.B
when (isThisBank(reads(1).bits)) {
reads(1).ready := false.B
when (isThisBank(only_read_req.bits)) {
only_read_req.ready := false.B
}
} .otherwise {
ren := isThisBank(reads(1).bits)
raddr := getBankIdx(reads(1).bits)
ren := isThisBank(only_read_req.bits) && only_read_req.fire()
raddr := getBankIdx(only_read_req.bits)
}

when (reset.asBool) {
w_q.foreach(_.valid := false.B)
}
}
}

val q = Module(new Queue(new AccumulatorReadResp(t, scale_t, log2Ceil(t.head.head.getWidth)), 1, true, true))
q.io.enq.bits.data := read_rdata
q.io.enq.bits.data := rdata_for_read_resp
q.io.enq.bits.scale := RegNext(io.read.req.bits.scale)
q.io.enq.bits.relu6_shift := RegNext(io.read.req.bits.relu6_shift)
q.io.enq.bits.act := RegNext(io.read.req.bits.act)
Expand All @@ -222,17 +320,18 @@ class AccumulatorMem[T <: Data, U <: Data](
val q_will_be_empty = (q.io.count +& q.io.enq.fire()) - q.io.deq.fire() === 0.U
io.read.req.ready := q_will_be_empty && (
// Make sure we aren't accumulating, which would take over both ports
!(io.write.fire() && io.write.bits.acc) &&
// Make sure we aren't reading something that is still being written
!(RegNext(io.write.fire()) && RegNext(io.write.bits.addr) === io.read.req.bits.addr) &&
!(w_buf_valid && waddr_buf === io.read.req.bits.addr) &&
!(io.write.valid && io.write.bits.acc) &&
!pipelined_writes.map(r => r.valid && r.bits.addr === io.read.req.bits.addr).reduce(_||_) &&
!block_read_req
)

io.write.ready := !io.write.bits.acc || (!(io.write.bits.addr === waddr_buf && w_buf_valid) &&
!(io.write.bits.addr === RegNext(io.write.bits.addr) && RegNext(io.write.fire())))
io.write.ready := !block_write_req &&
!pipelined_writes.map(r => r.valid && r.bits.addr === io.write.bits.addr && io.write.bits.acc).reduce(_||_)

when (reset.asBool()) {
pipelined_writes.foreach(_.valid := false.B)
}

// assert(!(io.read.req.valid && io.write.en && io.write.acc), "reading and accumulating simultaneously is not supported")
assert(!(io.read.req.fire() && io.write.fire() && io.read.req.bits.addr === io.write.bits.addr), "reading from and writing to same address is not supported")
assert(!(io.read.req.fire() && w_buf_valid && waddr_buf === io.read.req.bits.addr), "reading from an address immediately after writing to it is not supported")
}
Loading

0 comments on commit 44b28e4

Please sign in to comment.